Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions v2/pkg/engine/resolve/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,13 @@ type HookableSubscriptionDataSource interface {
// If an error is returned, the error is propagated to the client.
SubscriptionOnStart(ctx StartupHookContext, input []byte) (err error)
}

// HookablePubsubDatasource is an extension of HookableSubscriptionDataSource for pubsub datasources.
// They contain additional hooks which make sense for pubsub based datasources but not for normal
// subscription based datasources.
type HookablePubsubDatasource interface {
HookableSubscriptionDataSource
// SubscriptionOnCreate is called right before the trigger gets generated.
// It lets a source rewrite the subscription event configuration.
SubscriptionOnCreate(ctx context.Context, input []byte) (newInput []byte, err error)
}
16 changes: 16 additions & 0 deletions v2/pkg/engine/resolve/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -1335,6 +1335,14 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ
return nil
}

if hook, ok := subscription.Trigger.Source.(HookablePubsubDatasource); ok {
input, err = hook.SubscriptionOnCreate(ctx.Context(), input)
if err != nil {
msg := []byte(`{"errors":[{"message":"failed to prepare subscription trigger"}]}`)
return writeFlushComplete(writer, msg)
}
}

headers, triggerID, err := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input, subscription.Trigger.Source)
if err != nil {
msg := []byte(`{"errors":[{"message":"failed to prepare subscription trigger"}]}`)
Expand Down Expand Up @@ -1433,6 +1441,14 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G
return err
}

if hook, ok := subscription.Trigger.Source.(HookablePubsubDatasource); ok {
input, err = hook.SubscriptionOnCreate(ctx.Context(), input)
if err != nil {
msg := []byte(`{"errors":[{"message":"failed to prepare subscription trigger"}]}`)
return writeFlushComplete(writer, msg)
}
}

headers, triggerID, err := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input, subscription.Trigger.Source)
if err != nil {
msg := []byte(`{"errors":[{"message":"failed to prepare subscription trigger"}]}`)
Expand Down
335 changes: 335 additions & 0 deletions v2/pkg/engine/resolve/resolve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5823,6 +5823,26 @@ func (f *_fakeStream) Start(ctx *Context, headers http.Header, input []byte, upd
return nil
}

type _fakePubsubStream struct {
*_fakeStream

subscriptionOnCreateFn func(ctx context.Context, input []byte) (newInput []byte, err error)
}

func (f *_fakePubsubStream) SubscriptionOnCreate(ctx context.Context, input []byte) (newInput []byte, err error) {
if f.subscriptionOnCreateFn == nil {
return input, nil
}
return f.subscriptionOnCreateFn(ctx, input)
}

func createFakePubsubStream(messageFunc messageFunc, delay time.Duration, onStart func(input []byte), subscriptionOnCreateFn func(ctx context.Context, input []byte) (newInput []byte, err error)) *_fakePubsubStream {
return &_fakePubsubStream{
_fakeStream: createFakeStream(messageFunc, delay, onStart, nil),
subscriptionOnCreateFn: subscriptionOnCreateFn,
}
}

func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
defaultTimeout := time.Second * 30
if flags.IsWindows {
Expand Down Expand Up @@ -6844,6 +6864,321 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
assert.Contains(t, errorMessage, "errors", "Expected error message in GraphQL format")
assert.Contains(t, errorMessage, expectedErr.Error(), "Expected actual error message to be included")
})

t.Run("should call SubscriptionOnCreate hook on pubsub datasource", func(t *testing.T) {
c := t.Context()

called := make(chan bool, 1)

fakeStream := createFakePubsubStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0
}, 1*time.Millisecond, nil, func(ctx context.Context, input []byte) (newInput []byte, err error) {
called <- true
return input, nil
})

resolver, plan, recorder, id := setup(c, fakeStream)

ctx := &Context{
ctx: context.Background(),
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id)
assert.NoError(t, err)

select {
case <-called:
t.Log("SubscriptionOnCreate hook was called")
case <-time.After(defaultTimeout):
t.Fatal("SubscriptionOnCreate hook was not called")
}

recorder.AwaitComplete(t, defaultTimeout)
})

t.Run("SubscriptionOnCreate can rewrite the subscription input", func(t *testing.T) {
c := t.Context()

rewrittenInput := []byte(`{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { rewritten }"}}`)
startReceived := make(chan []byte, 1)

fakeStream := createFakePubsubStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0
}, 1*time.Millisecond, func(input []byte) {
startReceived <- append([]byte(nil), input...)
}, func(ctx context.Context, input []byte) (newInput []byte, err error) {
return rewrittenInput, nil
})

resolver, plan, recorder, id := setup(c, fakeStream)

ctx := &Context{
ctx: context.Background(),
}

err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id)
assert.NoError(t, err)

select {
case got := <-startReceived:
assert.Equal(t, string(rewrittenInput), string(got), "Start should receive the rewritten input")
case <-time.After(defaultTimeout):
t.Fatal("Start was not called")
}

recorder.AwaitComplete(t, defaultTimeout)
})

t.Run("should propagate errors from SubscriptionOnCreate hook", func(t *testing.T) {
c := t.Context()

fakeStream := createFakePubsubStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0
}, 1*time.Millisecond, nil, func(ctx context.Context, input []byte) (newInput []byte, err error) {
return nil, errors.New("create hook failed")
})

resolver, plan, recorder, id := setup(c, fakeStream)

ctx := &Context{
ctx: context.Background(),
}

err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id)
assert.NoError(t, err)

recorder.AwaitAnyMessageCount(t, defaultTimeout)
messages := recorder.Messages()
require.Greater(t, len(messages), 0)
assert.Contains(t, messages[0], "errors")
assert.Contains(t, messages[0], "failed to prepare subscription trigger")
})

t.Run("should call SubscriptionOnCreate hook on pubsub datasource (syncronous resolve)", func(t *testing.T) {
c := t.Context()

called := make(chan struct{}, 1)

fakeStream := createFakePubsubStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), true
}, 1*time.Millisecond, nil, func(ctx context.Context, input []byte) (newInput []byte, err error) {
called <- struct{}{}
return input, nil
})

resolver, plan, recorder, _ := setup(c, fakeStream)

ctx := NewContext(context.Background())

err := resolver.ResolveGraphQLSubscription(ctx, plan, recorder)
assert.NoError(t, err)

select {
case <-called:
t.Log("SubscriptionOnCreate hook was called")
default:
t.Fatal("SubscriptionOnCreate hook was not called")
}
})

t.Run("SubscriptionOnCreate can rewrite the subscription input (syncronous resolve)", func(t *testing.T) {
c := t.Context()

rewrittenInput := []byte(`{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { rewritten }"}}`)
var startInput []byte

fakeStream := createFakePubsubStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), true
}, 1*time.Millisecond, func(input []byte) {
startInput = append([]byte(nil), input...)
}, func(ctx context.Context, input []byte) (newInput []byte, err error) {
return rewrittenInput, nil
})

resolver, plan, recorder, _ := setup(c, fakeStream)

ctx := NewContext(context.Background())

err := resolver.ResolveGraphQLSubscription(ctx, plan, recorder)
assert.NoError(t, err)
assert.Equal(t, string(rewrittenInput), string(startInput), "Start should receive the rewritten input")
})

t.Run("should propagate errors from SubscriptionOnCreate hook (syncronous resolve)", func(t *testing.T) {
c := t.Context()

fakeStream := createFakePubsubStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), true
}, 1*time.Millisecond, nil, func(ctx context.Context, input []byte) (newInput []byte, err error) {
return nil, errors.New("create hook failed")
})

resolver, plan, recorder, _ := setup(c, fakeStream)

ctx := NewContext(context.Background())

err := resolver.ResolveGraphQLSubscription(ctx, plan, recorder)
assert.NoError(t, err)

messages := recorder.Messages()
require.Greater(t, len(messages), 0)
assert.Contains(t, messages[0], "errors")
assert.Contains(t, messages[0], "failed to prepare subscription trigger")
})

t.Run("it is possible to have two subscriptions to the same trigger with SubscriptionOnCreate", func(t *testing.T) {
c := t.Context()

createCallCount := atomic.Int32{}

// sub2Ready gates the data source goroutine so that it doesn't start
// emitting before sub2 has been registered on the trigger. Without this,
// the emitting goroutine's first triggerUpdate can race sub2's
// addSubscription on the unbuffered events channel, causing sub2 to
// miss counter=0.
sub2Ready := make(chan struct{})

fakeStream := createFakePubsubStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 100
}, 1*time.Millisecond, func(input []byte) {
<-sub2Ready
}, func(ctx context.Context, input []byte) (newInput []byte, err error) {
createCallCount.Add(1)
return input, nil
})

resolver1, plan1, recorder1, id1 := setup(c, fakeStream)
_, _, recorder2, id2 := setup(c, fakeStream)
id2.ConnectionID = id1.ConnectionID + 1
id2.SubscriptionID = id1.SubscriptionID + 1

ctx1 := &Context{ctx: context.Background()}
ctx2 := &Context{ctx: context.Background()}

err1 := resolver1.AsyncResolveGraphQLSubscription(ctx1, plan1, recorder1, id1)
assert.NoError(t, err1)

err2 := resolver1.AsyncResolveGraphQLSubscription(ctx2, plan1, recorder2, id2)
assert.NoError(t, err2)
close(sub2Ready)

recorder1.AwaitComplete(t, defaultTimeout)
require.Equal(t, 101, len(recorder1.Messages()))
assert.Equal(t, `{"data":{"counter":0}}`, recorder1.Messages()[0])
assert.Equal(t, `{"data":{"counter":100}}`, recorder1.Messages()[100])

recorder2.AwaitComplete(t, defaultTimeout)
require.Equal(t, 101, len(recorder2.Messages()))
assert.Equal(t, `{"data":{"counter":0}}`, recorder2.Messages()[0])
assert.Equal(t, `{"data":{"counter":100}}`, recorder2.Messages()[100])

assert.Equal(t, int32(2), createCallCount.Load(), "SubscriptionOnCreate should be called once per subscription")
})

t.Run("SubscriptionOnCreate merges two subscriptions with different inputs to one trigger", func(t *testing.T) {
c := t.Context()

createCallCount := atomic.Int32{}
startCallCount := atomic.Int32{}

// The canonical input that the hook normalizes both subscriptions to.
canonicalInput :=
[]byte(`{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`)

// sub2Ready gates the data source goroutine so that it doesn't start
// emitting before sub2 has been registered on the shared trigger.
sub2Ready := make(chan struct{})

fakeStream := createFakePubsubStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 100
}, 1*time.Millisecond, func(input []byte) {
startCallCount.Add(1)
<-sub2Ready
}, func(ctx context.Context, input []byte) (newInput []byte, err error) {
createCallCount.Add(1)
// Normalize any input to the canonical form so both subscriptions
// end up on the same trigger regardless of their original inputs.
return canonicalInput, nil
})

resolver1, plan1, recorder1, id1 := setup(c, fakeStream)
_, plan2, recorder2, id2 := setup(c, fakeStream)
id2.ConnectionID = id1.ConnectionID + 1
id2.SubscriptionID = id1.SubscriptionID + 1

plan1.Trigger.InputTemplate.Segments[0].Data =
[]byte(`{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { variant_a }"}}`)
plan2.Trigger.InputTemplate.Segments[0].Data =
[]byte(`{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { variant_b }"}}`)

ctx1 := &Context{ctx: context.Background()}
ctx2 := &Context{ctx: context.Background()}

err1 := resolver1.AsyncResolveGraphQLSubscription(ctx1, plan1, recorder1, id1)
assert.NoError(t, err1)

err2 := resolver1.AsyncResolveGraphQLSubscription(ctx2, plan2, recorder2, id2)
assert.NoError(t, err2)
close(sub2Ready)

recorder1.AwaitComplete(t, defaultTimeout)
require.Equal(t, 101, len(recorder1.Messages()))
assert.Equal(t, `{"data":{"counter":0}}`, recorder1.Messages()[0])
assert.Equal(t, `{"data":{"counter":100}}`, recorder1.Messages()[100])

recorder2.AwaitComplete(t, defaultTimeout)
require.Equal(t, 101, len(recorder2.Messages()))
assert.Equal(t, `{"data":{"counter":0}}`, recorder2.Messages()[0])
assert.Equal(t, `{"data":{"counter":100}}`, recorder2.Messages()[100])

assert.Equal(t, int32(2), createCallCount.Load(), "SubscriptionOnCreate should be called once per subscription")
assert.Equal(t, int32(1), startCallCount.Load(),
"trigger should be started only once because both subscriptions share the same trigger after hook normalization")
})

t.Run("SubscriptionOnCreate failure for one subscription does not affect the other", func(t *testing.T) {
c := t.Context()

callCount := atomic.Int32{}

fakeStream := createFakePubsubStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 2
}, 1*time.Millisecond, nil, func(ctx context.Context, input []byte) (newInput []byte, err error) {
// Second call (sub B) fails; first call (sub A) succeeds.
if callCount.Add(1) == 2 {
return nil, errors.New("create hook failed for sub B")
}
return input, nil
})

resolver1, plan1, recorder1, id1 := setup(c, fakeStream)
_, _, recorder2, id2 := setup(c, fakeStream)
id2.ConnectionID = id1.ConnectionID + 1
id2.SubscriptionID = id1.SubscriptionID + 1

ctx1 := &Context{ctx: context.Background()}
ctx2 := &Context{ctx: context.Background()}

err1 := resolver1.AsyncResolveGraphQLSubscription(ctx1, plan1, recorder1, id1)
assert.NoError(t, err1)

// Sub B's SubscriptionOnCreate fails synchronously, so recorder2 is
// already complete with an error when AsyncResolveGraphQLSubscription returns.
err2 := resolver1.AsyncResolveGraphQLSubscription(ctx2, plan1, recorder2, id2)
assert.NoError(t, err2)

require.True(t, recorder2.complete.Load(), "recorder2 should be complete immediately after the hook error")
messages2 := recorder2.Messages()
require.Len(t, messages2, 1)
assert.Contains(t, messages2[0], "errors")
assert.Contains(t, messages2[0], "failed to prepare subscription trigger")

// Sub A should continue and complete normally, unaffected by sub B's failure.
recorder1.AwaitComplete(t, defaultTimeout)
require.Equal(t, 3, len(recorder1.Messages()))
assert.Equal(t, `{"data":{"counter":0}}`, recorder1.Messages()[0])
assert.Equal(t, `{"data":{"counter":2}}`, recorder1.Messages()[2])
})
}

func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
Expand Down
Loading