Skip to content
Open
Empty file.
2 changes: 2 additions & 0 deletions db/migrations/20260610143041_add-agent-message-images.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- Adds the images column to agent_session_messages. Existing rows, and inserts
-- that omit the column, get an empty JSON array, so the column is never NULL.
ALTER TABLE agent_session_messages
ADD COLUMN images JSONB NOT NULL DEFAULT '[]'::jsonb;
3 changes: 2 additions & 1 deletion db/structure.sql
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ CREATE TABLE public.agent_session_messages (
tool_call_id text DEFAULT ''::text NOT NULL,
tool_name text DEFAULT ''::text NOT NULL,
tool_status character varying(20) DEFAULT ''::character varying NOT NULL,
created_at timestamp with time zone DEFAULT now() NOT NULL
created_at timestamp with time zone DEFAULT now() NOT NULL,
images jsonb DEFAULT '[]'::jsonb NOT NULL
);


Expand Down
29 changes: 29 additions & 0 deletions pkg/agents/anthropic/anthropic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,35 @@ func TestSendMessage_NoPreamble(t *testing.T) {
assert.Equal(t, "hi", text)
}

func TestSendMessage_IncludesImageContentBlocks(t *testing.T) {
var capturedBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
_ = json.Unmarshal(body, &capturedBody)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("{}"))
}))
defer server.Close()

p := newTestProvider(t, server)
err := p.SendMessage(context.Background(), "sesn_abc", "look at this", agents.SendMessageOptions{
Images: []agents.MessageImage{{MediaType: "image/png", Data: "aGVsbG8="}},
})
require.NoError(t, err)

events := capturedBody["events"].([]any)
content := events[0].(map[string]any)["content"].([]any)
require.Len(t, content, 2)
assert.Equal(t, "look at this", content[0].(map[string]any)["text"].(string))

imageBlock := content[1].(map[string]any)
assert.Equal(t, "image", imageBlock["type"].(string))
source := imageBlock["source"].(map[string]any)
assert.Equal(t, "base64", source["type"].(string))
assert.Equal(t, "image/png", source["media_type"].(string))
assert.Equal(t, "aGVsbG8=", source["data"].(string))
}

func TestSendMessage_RequiresSessionID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("server should not be hit")
Expand Down
16 changes: 15 additions & 1 deletion pkg/agents/anthropic/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,25 @@ func (p *Provider) SendMessage(ctx context.Context, providerSessionID, message s
return fmt.Errorf("anthropic: provider session id is required")
}

content := []map[string]any{
{"type": "text", "text": withPreamble(message, opts.ContextPreamble)},
}
for _, image := range opts.Images {
content = append(content, map[string]any{
"type": "image",
"source": map[string]string{
"type": "base64",
"media_type": image.MediaType,
"data": image.Data,
},
})
}

body := map[string]any{
"events": []map[string]any{
{
"type": "user.message",
"content": []map[string]string{{"type": "text", "text": withPreamble(message, opts.ContextPreamble)}},
"content": content,
},
},
}
Expand Down
10 changes: 10 additions & 0 deletions pkg/agents/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,21 @@ type CreateSessionResult struct {
ProviderSessionID string
}

// MessageImage is a base64-encoded image attached to a user message.
//
// NOTE(review): no validation of either field is visible in this excerpt;
// confirm that callers check the media type and payload size before use.
type MessageImage struct {
// MediaType is the IANA media type, e.g. "image/png".
MediaType string
// Data is the base64-encoded image bytes, without a data URI prefix.
Data string
}

// SendMessageOptions carries the optional inputs that accompany a user
// message when it is sent to a provider.
type SendMessageOptions struct {
// ContextPreamble is prepended to the user's message so providers that
// need caller context inline (e.g. a CLI token on first turn) receive it
// without a separate system message.
ContextPreamble string
// Images are attachments sent to the agent alongside the text content.
Images []MessageImage
}

type Provider interface {
Expand Down
18 changes: 14 additions & 4 deletions pkg/agents/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/superplanehq/superplane/pkg/grpc/actions/messages"
"github.com/superplanehq/superplane/pkg/jwt"
"github.com/superplanehq/superplane/pkg/models"
"gorm.io/datatypes"
"gorm.io/gorm"
)

Expand Down Expand Up @@ -237,8 +238,8 @@ func (s *Service) DefineOutcome(ctx context.Context, organizationID, userID, ses
return nil
}

func (s *Service) SendMessage(ctx context.Context, organizationID, userID, sessionID uuid.UUID, content string, mode ...string) (*models.AgentSessionMessage, error) {
if content == "" {
func (s *Service) SendMessage(ctx context.Context, organizationID, userID, sessionID uuid.UUID, content string, images []MessageImage, mode ...string) (*models.AgentSessionMessage, error) {
if content == "" && len(images) == 0 {
return nil, fmt.Errorf("message content is required")
}

Expand All @@ -257,7 +258,7 @@ func (s *Service) SendMessage(ctx context.Context, organizationID, userID, sessi
return nil, fmt.Errorf("build preamble: %w", err)
}

if err := s.provider.SendMessage(ctx, session.ProviderSessionID, content, SendMessageOptions{ContextPreamble: preamble}); err != nil {
if err := s.provider.SendMessage(ctx, session.ProviderSessionID, content, SendMessageOptions{ContextPreamble: preamble, Images: images}); err != nil {
if errors.Is(err, ErrSessionBusy) {
return nil, s.handleBusySession(sessionID, organizationID, userID)
}
Expand All @@ -269,7 +270,7 @@ func (s *Service) SendMessage(ctx context.Context, organizationID, userID, sessi
}
return nil, recoverErr
}
if err := s.provider.SendMessage(ctx, recovered.ProviderSessionID, content, SendMessageOptions{ContextPreamble: preamble}); err != nil {
if err := s.provider.SendMessage(ctx, recovered.ProviderSessionID, content, SendMessageOptions{ContextPreamble: preamble, Images: images}); err != nil {
if errors.Is(err, ErrSessionBusy) {
return nil, s.handleBusySession(sessionID, organizationID, userID)
}
Expand All @@ -289,6 +290,7 @@ func (s *Service) SendMessage(ctx context.Context, organizationID, userID, sessi
SessionID: sessionID,
Role: messageRole,
Content: content,
Images: toSessionImages(images),
}
if err := models.AppendAgentSessionMessage(persisted); err != nil {
return nil, fmt.Errorf("persist user message: %w", err)
Expand All @@ -304,6 +306,14 @@ func (s *Service) SendMessage(ctx context.Context, organizationID, userID, sessi
return persisted, nil
}

// toSessionImages converts the images attached to an outgoing message into
// the model type stored on the session message. The result is always a
// non-nil slice, so a message without images yields an empty slice rather
// than nil.
func toSessionImages(images []MessageImage) datatypes.JSONSlice[models.AgentSessionImage] {
	converted := make(datatypes.JSONSlice[models.AgentSessionImage], len(images))
	for i := range images {
		converted[i] = models.AgentSessionImage{
			MediaType: images[i].MediaType,
			Data:      images[i].Data,
		}
	}
	return converted
}

func (s *Service) handleBusySession(sessionID, organizationID, userID uuid.UUID) error {
if err := s.enqueueStreamAfterBusySession(sessionID, organizationID, userID); err != nil {
return err
Expand Down
56 changes: 42 additions & 14 deletions pkg/agents/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type fakeProvider struct {
sentSessions []string
defineSessions []string
lastPreamble string
lastImages []agents.MessageImage
lastOutcomeOpts agents.DefineOutcomeOptions
createSessionErr error
createHook func() error
Expand Down Expand Up @@ -90,6 +91,7 @@ func (f *fakeProvider) SendMessage(_ context.Context, providerSessionID string,
f.sendCalled++
f.sentSessions = append(f.sentSessions, providerSessionID)
f.lastPreamble = opts.ContextPreamble
f.lastImages = opts.Images
if len(f.sendErrs) > 0 {
err := f.sendErrs[0]
f.sendErrs = f.sendErrs[1:]
Expand Down Expand Up @@ -274,13 +276,39 @@ func TestService_SendMessage_ReturnsPersistedUserMessage(t *testing.T) {
session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID)
require.NoError(t, err)

persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello")
persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello", nil)
require.NoError(t, err)
require.NotNil(t, persisted)
require.NotEqual(t, uuid.Nil, persisted.ID)
assert.Equal(t, "hello", persisted.Content)
}

// TestService_SendMessage_ForwardsAndPersistsImages sends an image-only
// message (empty text) and checks that the attachment reaches the provider,
// is set on the returned message, and is returned by ListMessages.
func TestService_SendMessage_ForwardsAndPersistsImages(t *testing.T) {
	r := support.Setup(t)
	defer r.Close()

	ctx := context.Background()
	canvas := setupCanvasForUser(t, r)
	provider := &fakeProvider{}
	svc := newService(t, r, provider)

	session, err := svc.EnsureSession(ctx, r.Organization.ID, r.User, canvas.ID)
	require.NoError(t, err)

	attachment := agents.MessageImage{MediaType: "image/png", Data: "aGVsbG8="}
	persisted, err := svc.SendMessage(ctx, r.Organization.ID, r.User, session.ID, "", []agents.MessageImage{attachment})
	require.NoError(t, err)

	// The provider must have received the attachment.
	require.Len(t, provider.lastImages, 1)
	assert.Equal(t, "image/png", provider.lastImages[0].MediaType)

	// The returned message must carry the attachment.
	require.Len(t, persisted.Images, 1)
	assert.Equal(t, "aGVsbG8=", persisted.Images[0].Data)

	// The attachment must also come back when the history is listed.
	history, err := svc.ListMessages(session.ID, uuid.Nil, 10)
	require.NoError(t, err)
	require.Len(t, history, 1)
	require.Len(t, history[0].Images, 1)
	assert.Equal(t, "image/png", history[0].Images[0].MediaType)
}

func TestService_SendMessage_AllowsFollowUpWhenSessionIsStreaming(t *testing.T) {
r := support.Setup(t)
defer r.Close()
Expand All @@ -293,7 +321,7 @@ func TestService_SendMessage_AllowsFollowUpWhenSessionIsStreaming(t *testing.T)
require.NoError(t, err)
require.NoError(t, models.UpdateAgentSessionStatus(session.ID, models.AgentSessionStatusStreaming))

persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello")
persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello", nil)
require.NoError(t, err)
require.NotNil(t, persisted)
assert.Equal(t, 1, provider.sendCalled)
Expand All @@ -314,7 +342,7 @@ func TestService_SendMessage_ProviderBusyKeepsSessionStreaming(t *testing.T) {
session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID)
require.NoError(t, err)

persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello")
persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello", nil)
require.ErrorIs(t, err, agents.ErrSessionBusy)
require.Nil(t, persisted)

Expand All @@ -337,7 +365,7 @@ func TestService_SendMessage_RecreatesUnavailableProviderSession(t *testing.T) {
require.NoError(t, err)
originalProviderSessionID := session.ProviderSessionID

persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello")
persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello", nil)
require.NoError(t, err)
require.NotNil(t, persisted)

Expand All @@ -363,7 +391,7 @@ func TestService_SendMessage_ReturnsBusyWhenRecoveredProviderSessionIsBusy(t *te
session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID)
require.NoError(t, err)

persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello")
persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello", nil)
require.ErrorIs(t, err, agents.ErrSessionBusy)
require.Nil(t, persisted)

Expand Down Expand Up @@ -398,7 +426,7 @@ func TestService_SendMessage_DoesNotHoldSessionLockWhileCreatingRecoveredProvide
require.NoError(t, err)
sessionID = session.ID

persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello")
persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello", nil)
require.NoError(t, err)
require.NotNil(t, persisted)
}
Expand Down Expand Up @@ -438,7 +466,7 @@ func TestService_SendMessage_RecoversFailedSession(t *testing.T) {
require.NoError(t, err)
require.NoError(t, models.UpdateAgentSessionStatus(session.ID, models.AgentSessionStatusFailed))

persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "retry")
persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "retry", nil)
require.NoError(t, err)
require.NotNil(t, persisted)
assert.Equal(t, 1, provider.sendCalled)
Expand All @@ -459,7 +487,7 @@ func TestService_SendMessage_RefreshesPreambleEveryTurn(t *testing.T) {
session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID)
require.NoError(t, err)

_, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "first")
_, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "first", nil)
require.NoError(t, err)
assert.Contains(t, provider.lastPreamble, canvas.ID.String())
assert.Contains(t, provider.lastPreamble, "api_token:")
Expand All @@ -475,7 +503,7 @@ func TestService_SendMessage_RefreshesPreambleEveryTurn(t *testing.T) {

require.NoError(t, models.UpdateAgentSessionStatus(session.ID, models.AgentSessionStatusIdle))
provider.lastPreamble = "<sentinel>"
_, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "second")
_, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "second", nil)
require.NoError(t, err)
assert.Contains(t, provider.lastPreamble, "api_token:",
"a fresh api_token must be re-injected on every turn so the session never expires mid-conversation")
Expand All @@ -492,12 +520,12 @@ func TestService_SendMessage_FirstTurnPreambleSurvivesProviderFailure(t *testing
session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID)
require.NoError(t, err)

_, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "first")
_, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "first", nil)
require.Error(t, err)

provider.sendErr = nil
provider.lastPreamble = "<sentinel>"
_, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "retry")
_, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "retry", nil)
require.NoError(t, err)
assert.Contains(t, provider.lastPreamble, "api_token:",
"preamble must still be injected after the previous attempt failed at the provider")
Expand Down Expand Up @@ -543,7 +571,7 @@ func TestService_SendMessage_PrivateToUser(t *testing.T) {
session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID)
require.NoError(t, err)

_, err = svc.SendMessage(context.Background(), r.Organization.ID, uuid.New(), session.ID, "intrusion")
_, err = svc.SendMessage(context.Background(), r.Organization.ID, uuid.New(), session.ID, "intrusion", nil)
require.Error(t, err)
assert.Equal(t, 0, provider.sendCalled)
}
Expand All @@ -559,7 +587,7 @@ func TestService_SendMessage_RejectsEmpty(t *testing.T) {
session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID)
require.NoError(t, err)

_, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "")
_, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "", nil)
require.Error(t, err)
assert.Equal(t, 0, provider.sendCalled)
}
Expand All @@ -575,7 +603,7 @@ func TestService_ListMessages_TailPagination(t *testing.T) {
session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID)
require.NoError(t, err)
for i := 0; i < 5; i++ {
_, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "m")
_, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "m", nil)
require.NoError(t, err)
require.NoError(t, models.UpdateAgentSessionStatus(session.ID, models.AgentSessionStatusIdle))
}
Expand Down
19 changes: 18 additions & 1 deletion pkg/grpc/actions/agents/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type AgentsService interface {
EnsureSession(ctx context.Context, organizationID, userID, canvasID uuid.UUID) (*models.AgentSession, error)
GetSession(organizationID, userID, sessionID uuid.UUID) (*models.AgentSession, error)
ListMessages(sessionID, beforeID uuid.UUID, limit int) ([]models.AgentSessionMessage, error)
SendMessage(ctx context.Context, organizationID, userID, sessionID uuid.UUID, content string, mode ...string) (*models.AgentSessionMessage, error)
SendMessage(ctx context.Context, organizationID, userID, sessionID uuid.UUID, content string, images []agentservice.MessageImage, mode ...string) (*models.AgentSessionMessage, error)
InterruptSession(ctx context.Context, organizationID, userID, sessionID uuid.UUID) error
DefineOutcome(ctx context.Context, organizationID, userID, sessionID uuid.UUID, description, rubric string, maxIterations int) error
}
Expand Down Expand Up @@ -81,9 +81,26 @@ func serializeMessage(message *models.AgentSessionMessage) *pb.AgentChatMessage
ToolCallId: message.ToolCallID,
ToolName: message.ToolName,
ToolStatus: message.ToolStatus,
Images: serializeImages(message.Images),
}
if message.CreatedAt != nil {
out.CreatedAt = timestamppb.New(*message.CreatedAt)
}
return out
}

// serializeImages returns image metadata only. The base64 bytes are served
// out-of-band by the dedicated image endpoint (see handleAgentChatMessageImage)
// and intentionally omitted here, so a history page with several attachments
// stays under the gRPC/HTTP response size limits. The client builds the image
// URL from the message id and the image's position in this slice.
func serializeImages(images []models.AgentSessionImage) []*pb.AgentChatImage {
if len(images) == 0 {
return nil
}
out := make([]*pb.AgentChatImage, 0, len(images))
for _, image := range images {
out = append(out, &pb.AgentChatImage{MediaType: image.MediaType})
}
return out
Comment thread
cursor[bot] marked this conversation as resolved.
}
7 changes: 4 additions & 3 deletions pkg/grpc/actions/agents/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/google/uuid"
agentservice "github.com/superplanehq/superplane/pkg/agents"
"github.com/superplanehq/superplane/pkg/models"
"github.com/superplanehq/superplane/test/support"
)
Expand All @@ -14,7 +15,7 @@ type stubService struct {
ensureSession func(context.Context, uuid.UUID, uuid.UUID, uuid.UUID) (*models.AgentSession, error)
getSession func(uuid.UUID, uuid.UUID, uuid.UUID) (*models.AgentSession, error)
listMessages func(uuid.UUID, uuid.UUID, int) ([]models.AgentSessionMessage, error)
sendMessage func(context.Context, uuid.UUID, uuid.UUID, uuid.UUID, string, string) (*models.AgentSessionMessage, error)
sendMessage func(context.Context, uuid.UUID, uuid.UUID, uuid.UUID, string, []agentservice.MessageImage, string) (*models.AgentSessionMessage, error)
interruptErr error
defineErr error
}
Expand All @@ -28,12 +29,12 @@ func (s *stubService) GetSession(o, u, id uuid.UUID) (*models.AgentSession, erro
func (s *stubService) ListMessages(id, before uuid.UUID, limit int) ([]models.AgentSessionMessage, error) {
return s.listMessages(id, before, limit)
}
func (s *stubService) SendMessage(ctx context.Context, o, u, id uuid.UUID, content string, mode ...string) (*models.AgentSessionMessage, error) {
func (s *stubService) SendMessage(ctx context.Context, o, u, id uuid.UUID, content string, images []agentservice.MessageImage, mode ...string) (*models.AgentSessionMessage, error) {
selectedMode := ""
if len(mode) > 0 {
selectedMode = mode[0]
}
return s.sendMessage(ctx, o, u, id, content, selectedMode)
return s.sendMessage(ctx, o, u, id, content, images, selectedMode)
}

func (s *stubService) InterruptSession(ctx context.Context, o, u, id uuid.UUID) error {
Expand Down
Loading