diff --git a/db/migrations/20260610143041_add-agent-message-images.down.sql b/db/migrations/20260610143041_add-agent-message-images.down.sql new file mode 100644 index 0000000000..e69de29bb2 diff --git a/db/migrations/20260610143041_add-agent-message-images.up.sql b/db/migrations/20260610143041_add-agent-message-images.up.sql new file mode 100644 index 0000000000..9dac8947e6 --- /dev/null +++ b/db/migrations/20260610143041_add-agent-message-images.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE agent_session_messages + ADD COLUMN images JSONB NOT NULL DEFAULT '[]'::jsonb; diff --git a/db/structure.sql b/db/structure.sql index 0a3770cb42..43823767ff 100644 --- a/db/structure.sql +++ b/db/structure.sql @@ -114,7 +114,8 @@ CREATE TABLE public.agent_session_messages ( tool_call_id text DEFAULT ''::text NOT NULL, tool_name text DEFAULT ''::text NOT NULL, tool_status character varying(20) DEFAULT ''::character varying NOT NULL, - created_at timestamp with time zone DEFAULT now() NOT NULL + created_at timestamp with time zone DEFAULT now() NOT NULL, + images jsonb DEFAULT '[]'::jsonb NOT NULL ); diff --git a/pkg/agents/anthropic/anthropic_test.go b/pkg/agents/anthropic/anthropic_test.go index 098c89b1f4..4597a89eee 100644 --- a/pkg/agents/anthropic/anthropic_test.go +++ b/pkg/agents/anthropic/anthropic_test.go @@ -352,6 +352,35 @@ func TestSendMessage_NoPreamble(t *testing.T) { assert.Equal(t, "hi", text) } +func TestSendMessage_IncludesImageContentBlocks(t *testing.T) { + var capturedBody map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &capturedBody) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("{}")) + })) + defer server.Close() + + p := newTestProvider(t, server) + err := p.SendMessage(context.Background(), "sesn_abc", "look at this", agents.SendMessageOptions{ + Images: []agents.MessageImage{{MediaType: "image/png", Data: "aGVsbG8="}}, + }) + require.NoError(t, err) + + events := capturedBody["events"].([]any) + content := events[0].(map[string]any)["content"].([]any) + require.Len(t, content, 2) + assert.Equal(t, "look at this", content[0].(map[string]any)["text"].(string)) + + imageBlock := content[1].(map[string]any) + assert.Equal(t, "image", imageBlock["type"].(string)) + source := imageBlock["source"].(map[string]any) + assert.Equal(t, "base64", source["type"].(string)) + assert.Equal(t, "image/png", source["media_type"].(string)) + assert.Equal(t, "aGVsbG8=", source["data"].(string)) +} + func TestSendMessage_RequiresSessionID(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatalf("server should not be hit") diff --git a/pkg/agents/anthropic/provider.go b/pkg/agents/anthropic/provider.go index 77b9c9eb80..65a3cbfa00 100644 --- a/pkg/agents/anthropic/provider.go +++ b/pkg/agents/anthropic/provider.go @@ -96,11 +96,25 @@ func (p *Provider) SendMessage(ctx context.Context, providerSessionID, message s return fmt.Errorf("anthropic: provider session id is required") } + content := []map[string]any{ + {"type": "text", "text": withPreamble(message, opts.ContextPreamble)}, + } + for _, image := range opts.Images { + content = append(content, map[string]any{ + "type": "image", + "source": map[string]string{ + "type": "base64", + "media_type": image.MediaType, + "data": image.Data, + }, + }) + } + body := map[string]any{ "events": []map[string]any{ { "type": "user.message", - "content": []map[string]string{{"type": "text", "text": withPreamble(message, opts.ContextPreamble)}}, + "content": content, }, }, } diff --git a/pkg/agents/provider.go b/pkg/agents/provider.go index 607cebbe08..169b1ad328 100644 --- a/pkg/agents/provider.go +++ b/pkg/agents/provider.go @@ -138,11 +138,21 @@ type CreateSessionResult struct { ProviderSessionID string } +// MessageImage is a base64-encoded image attached to a user message. +type MessageImage struct { + // MediaType is the IANA media type, e.g. "image/png". + MediaType string + // Data is the base64-encoded image bytes, without a data URI prefix. + Data string +} + // SendMessageOptions.ContextPreamble is prepended to the user's message so // providers that need caller context inline (e.g. a CLI token on first turn) // receive it without a separate system message. type SendMessageOptions struct { ContextPreamble string + // Images are attachments sent to the agent alongside the text content. + Images []MessageImage } type Provider interface { diff --git a/pkg/agents/service.go b/pkg/agents/service.go index 7f4df46e27..4a4e767ed9 100644 --- a/pkg/agents/service.go +++ b/pkg/agents/service.go @@ -16,6 +16,7 @@ import ( "github.com/superplanehq/superplane/pkg/grpc/actions/messages" "github.com/superplanehq/superplane/pkg/jwt" "github.com/superplanehq/superplane/pkg/models" + "gorm.io/datatypes" "gorm.io/gorm" ) @@ -237,8 +238,8 @@ func (s *Service) DefineOutcome(ctx context.Context, organizationID, userID, ses return nil } -func (s *Service) SendMessage(ctx context.Context, organizationID, userID, sessionID uuid.UUID, content string, mode ...string) (*models.AgentSessionMessage, error) { - if content == "" { +func (s *Service) SendMessage(ctx context.Context, organizationID, userID, sessionID uuid.UUID, content string, images []MessageImage, mode ...string) (*models.AgentSessionMessage, error) { + if content == "" && len(images) == 0 { return nil, fmt.Errorf("message content is required") } @@ -257,7 +258,7 @@ func (s *Service) SendMessage(ctx context.Context, organizationID, userID, sessi return nil, fmt.Errorf("build preamble: %w", err) } - if err := s.provider.SendMessage(ctx, session.ProviderSessionID, content, SendMessageOptions{ContextPreamble: preamble}); err != nil { + if err := s.provider.SendMessage(ctx, session.ProviderSessionID, content, SendMessageOptions{ContextPreamble: preamble, Images: images}); err != nil { if errors.Is(err, ErrSessionBusy) { return nil, s.handleBusySession(sessionID, organizationID, userID) } @@ -269,7 +270,7 @@ func (s *Service) SendMessage(ctx context.Context, organizationID, userID, sessi } return nil, recoverErr } - if err := s.provider.SendMessage(ctx, recovered.ProviderSessionID, content, SendMessageOptions{ContextPreamble: preamble}); err != nil { + if err := s.provider.SendMessage(ctx, recovered.ProviderSessionID, content, SendMessageOptions{ContextPreamble: preamble, Images: images}); err != nil { if errors.Is(err, ErrSessionBusy) { return nil, s.handleBusySession(sessionID, organizationID, userID) } @@ -289,6 +290,7 @@ func (s *Service) SendMessage(ctx context.Context, organizationID, userID, sessi SessionID: sessionID, Role: messageRole, Content: content, + Images: toSessionImages(images), } if err := models.AppendAgentSessionMessage(persisted); err != nil { return nil, fmt.Errorf("persist user message: %w", err) @@ -304,6 +306,14 @@ func (s *Service) SendMessage(ctx context.Context, organizationID, userID, sessi return persisted, nil } +func toSessionImages(images []MessageImage) datatypes.JSONSlice[models.AgentSessionImage] { + out := make(datatypes.JSONSlice[models.AgentSessionImage], 0, len(images)) + for _, image := range images { + out = append(out, models.AgentSessionImage{MediaType: image.MediaType, Data: image.Data}) + } + return out +} + func (s *Service) handleBusySession(sessionID, organizationID, userID uuid.UUID) error { if err := s.enqueueStreamAfterBusySession(sessionID, organizationID, userID); err != nil { return err diff --git a/pkg/agents/service_test.go b/pkg/agents/service_test.go index 984ed7d044..8ede999ff7 100644 --- a/pkg/agents/service_test.go +++ b/pkg/agents/service_test.go @@ -57,6 +57,7 @@ type fakeProvider struct { sentSessions []string defineSessions []string lastPreamble string + lastImages []agents.MessageImage lastOutcomeOpts agents.DefineOutcomeOptions createSessionErr error createHook func() error @@ -90,6 +91,7 @@ func (f *fakeProvider) SendMessage(_ context.Context, providerSessionID string, f.sendCalled++ f.sentSessions = append(f.sentSessions, providerSessionID) f.lastPreamble = opts.ContextPreamble + f.lastImages = opts.Images if len(f.sendErrs) > 0 { err := f.sendErrs[0] f.sendErrs = f.sendErrs[1:] @@ -274,13 +276,39 @@ func TestService_SendMessage_ReturnsPersistedUserMessage(t *testing.T) { session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID) require.NoError(t, err) - persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello") + persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello", nil) require.NoError(t, err) require.NotNil(t, persisted) require.NotEqual(t, uuid.Nil, persisted.ID) assert.Equal(t, "hello", persisted.Content) } +func TestService_SendMessage_ForwardsAndPersistsImages(t *testing.T) { + r := support.Setup(t) + defer r.Close() + + canvas := setupCanvasForUser(t, r) + provider := &fakeProvider{} + svc := newService(t, r, provider) + + session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID) + require.NoError(t, err) + + images := []agents.MessageImage{{MediaType: "image/png", Data: "aGVsbG8="}} + persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "", images) + require.NoError(t, err) + require.Len(t, provider.lastImages, 1) + assert.Equal(t, "image/png", provider.lastImages[0].MediaType) + require.Len(t, persisted.Images, 1) + assert.Equal(t, "aGVsbG8=", persisted.Images[0].Data) + + stored, err := svc.ListMessages(session.ID, uuid.Nil, 10) + require.NoError(t, err) + require.Len(t, stored, 1) + require.Len(t, stored[0].Images, 1) + assert.Equal(t, "image/png", stored[0].Images[0].MediaType) +} + func TestService_SendMessage_AllowsFollowUpWhenSessionIsStreaming(t *testing.T) { r := support.Setup(t) defer r.Close() @@ -293,7 +321,7 @@ func TestService_SendMessage_AllowsFollowUpWhenSessionIsStreaming(t *testing.T) require.NoError(t, err) require.NoError(t, models.UpdateAgentSessionStatus(session.ID, models.AgentSessionStatusStreaming)) - persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello") + persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello", nil) require.NoError(t, err) require.NotNil(t, persisted) assert.Equal(t, 1, provider.sendCalled) @@ -314,7 +342,7 @@ func TestService_SendMessage_ProviderBusyKeepsSessionStreaming(t *testing.T) { session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID) require.NoError(t, err) - persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello") + persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello", nil) require.ErrorIs(t, err, agents.ErrSessionBusy) require.Nil(t, persisted) @@ -337,7 +365,7 @@ func TestService_SendMessage_RecreatesUnavailableProviderSession(t *testing.T) { require.NoError(t, err) originalProviderSessionID := session.ProviderSessionID - persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello") + persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello", nil) require.NoError(t, err) require.NotNil(t, persisted) @@ -363,7 +391,7 @@ func TestService_SendMessage_ReturnsBusyWhenRecoveredProviderSessionIsBusy(t *te session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID) require.NoError(t, err) - persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello") + persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello", nil) require.ErrorIs(t, err, agents.ErrSessionBusy) require.Nil(t, persisted) @@ -398,7 +426,7 @@ func TestService_SendMessage_DoesNotHoldSessionLockWhileCreatingRecoveredProvide require.NoError(t, err) sessionID = session.ID - persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello") + persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "hello", nil) require.NoError(t, err) require.NotNil(t, persisted) } @@ -438,7 +466,7 @@ func TestService_SendMessage_RecoversFailedSession(t *testing.T) { require.NoError(t, err) require.NoError(t, models.UpdateAgentSessionStatus(session.ID, models.AgentSessionStatusFailed)) - persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "retry") + persisted, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "retry", nil) require.NoError(t, err) require.NotNil(t, persisted) assert.Equal(t, 1, provider.sendCalled) @@ -459,7 +487,7 @@ func TestService_SendMessage_RefreshesPreambleEveryTurn(t *testing.T) { session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID) require.NoError(t, err) - _, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "first") + _, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "first", nil) require.NoError(t, err) assert.Contains(t, provider.lastPreamble, canvas.ID.String()) assert.Contains(t, provider.lastPreamble, "api_token:") @@ -475,7 +503,7 @@ func TestService_SendMessage_RefreshesPreambleEveryTurn(t *testing.T) { require.NoError(t, models.UpdateAgentSessionStatus(session.ID, models.AgentSessionStatusIdle)) provider.lastPreamble = "" - _, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "second") + _, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "second", nil) require.NoError(t, err) assert.Contains(t, provider.lastPreamble, "api_token:", "a fresh api_token must be re-injected on every turn so the session never expires mid-conversation") @@ -492,12 +520,12 @@ func TestService_SendMessage_FirstTurnPreambleSurvivesProviderFailure(t *testing session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID) require.NoError(t, err) - _, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "first") + _, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "first", nil) require.Error(t, err) provider.sendErr = nil provider.lastPreamble = "" - _, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "retry") + _, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "retry", nil) require.NoError(t, err) assert.Contains(t, provider.lastPreamble, "api_token:", "preamble must still be injected after the previous attempt failed at the provider") @@ -543,7 +571,7 @@ func TestService_SendMessage_PrivateToUser(t *testing.T) { session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID) require.NoError(t, err) - _, err = svc.SendMessage(context.Background(), r.Organization.ID, uuid.New(), session.ID, "intrusion") + _, err = svc.SendMessage(context.Background(), r.Organization.ID, uuid.New(), session.ID, "intrusion", nil) require.Error(t, err) assert.Equal(t, 0, provider.sendCalled) } @@ -559,7 +587,7 @@ func TestService_SendMessage_RejectsEmpty(t *testing.T) { session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID) require.NoError(t, err) - _, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "") + _, err = svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "", nil) require.Error(t, err) assert.Equal(t, 0, provider.sendCalled) } @@ -575,7 +603,7 @@ func TestService_ListMessages_TailPagination(t *testing.T) { session, err := svc.EnsureSession(context.Background(), r.Organization.ID, r.User, canvas.ID) require.NoError(t, err) for i := 0; i < 5; i++ { - _, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "m") + _, err := svc.SendMessage(context.Background(), r.Organization.ID, r.User, session.ID, "m", nil) require.NoError(t, err) require.NoError(t, models.UpdateAgentSessionStatus(session.ID, models.AgentSessionStatusIdle)) } diff --git a/pkg/grpc/actions/agents/common.go b/pkg/grpc/actions/agents/common.go index e3dd4bd8d2..b8179b6789 100644 --- a/pkg/grpc/actions/agents/common.go +++ b/pkg/grpc/actions/agents/common.go @@ -19,7 +19,7 @@ type AgentsService interface { EnsureSession(ctx context.Context, organizationID, userID, canvasID uuid.UUID) (*models.AgentSession, error) GetSession(organizationID, userID, sessionID uuid.UUID) (*models.AgentSession, error) ListMessages(sessionID, beforeID uuid.UUID, limit int) ([]models.AgentSessionMessage, error) - SendMessage(ctx context.Context, organizationID, userID, sessionID uuid.UUID, content string, mode ...string) (*models.AgentSessionMessage, error) + SendMessage(ctx context.Context, organizationID, userID, sessionID uuid.UUID, content string, images []agentservice.MessageImage, mode ...string) (*models.AgentSessionMessage, error) InterruptSession(ctx context.Context, organizationID, userID, sessionID uuid.UUID) error DefineOutcome(ctx context.Context, organizationID, userID, sessionID uuid.UUID, description, rubric string, maxIterations int) error } @@ -81,9 +81,26 @@ func serializeMessage(message *models.AgentSessionMessage) *pb.AgentChatMessage ToolCallId: message.ToolCallID, ToolName: message.ToolName, ToolStatus: message.ToolStatus, + Images: serializeImages(message.Images), } if message.CreatedAt != nil { out.CreatedAt = timestamppb.New(*message.CreatedAt) } return out } + +// serializeImages returns image metadata only. The base64 bytes are served +// out-of-band by the dedicated image endpoint (see handleAgentChatMessageImage) +// and intentionally omitted here, so a history page with several attachments +// stays under the gRPC/HTTP response size limits. The client builds the image +// URL from the message id and the image's position in this slice. +func serializeImages(images []models.AgentSessionImage) []*pb.AgentChatImage { + if len(images) == 0 { + return nil + } + out := make([]*pb.AgentChatImage, 0, len(images)) + for _, image := range images { + out = append(out, &pb.AgentChatImage{MediaType: image.MediaType}) + } + return out +} diff --git a/pkg/grpc/actions/agents/helpers_test.go b/pkg/grpc/actions/agents/helpers_test.go index ca75562788..44eed5fce8 100644 --- a/pkg/grpc/actions/agents/helpers_test.go +++ b/pkg/grpc/actions/agents/helpers_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/google/uuid" + agentservice "github.com/superplanehq/superplane/pkg/agents" "github.com/superplanehq/superplane/pkg/models" "github.com/superplanehq/superplane/test/support" ) @@ -14,7 +15,7 @@ type stubService struct { ensureSession func(context.Context, uuid.UUID, uuid.UUID, uuid.UUID) (*models.AgentSession, error) getSession func(uuid.UUID, uuid.UUID, uuid.UUID) (*models.AgentSession, error) listMessages func(uuid.UUID, uuid.UUID, int) ([]models.AgentSessionMessage, error) - sendMessage func(context.Context, uuid.UUID, uuid.UUID, uuid.UUID, string, string) (*models.AgentSessionMessage, error) + sendMessage func(context.Context, uuid.UUID, uuid.UUID, uuid.UUID, string, []agentservice.MessageImage, string) (*models.AgentSessionMessage, error) interruptErr error defineErr error } @@ -28,12 +29,12 @@ func (s *stubService) GetSession(o, u, id uuid.UUID) (*models.AgentSession, erro func (s *stubService) ListMessages(id, before uuid.UUID, limit int) ([]models.AgentSessionMessage, error) { return s.listMessages(id, before, limit) } -func (s *stubService) SendMessage(ctx context.Context, o, u, id uuid.UUID, content string, mode ...string) (*models.AgentSessionMessage, error) { +func (s *stubService) SendMessage(ctx context.Context, o, u, id uuid.UUID, content string, images []agentservice.MessageImage, mode ...string) (*models.AgentSessionMessage, error) { selectedMode := "" if len(mode) > 0 { selectedMode = mode[0] } - return s.sendMessage(ctx, o, u, id, content, selectedMode) + return s.sendMessage(ctx, o, u, id, content, images, selectedMode) } func (s *stubService) InterruptSession(ctx context.Context, o, u, id uuid.UUID) error { diff --git a/pkg/grpc/actions/agents/send_agent_chat_message.go b/pkg/grpc/actions/agents/send_agent_chat_message.go index 3940886431..ad6883a328 100644 --- a/pkg/grpc/actions/agents/send_agent_chat_message.go +++ b/pkg/grpc/actions/agents/send_agent_chat_message.go @@ -2,7 +2,9 @@ package agents import ( "context" + "encoding/base64" "errors" + "slices" "github.com/google/uuid" log "github.com/sirupsen/logrus" @@ -13,6 +15,18 @@ import ( "gorm.io/gorm" ) +const ( + maxChatImages = 8 + + // maxChatImagePayloadBytes caps the combined decoded image bytes per message. + // Images are sent as base64 (~4/3 larger) in the protobuf body, so this stays + // well under the gRPC server's 4 MiB receive limit, leaving room for the + // message text and framing. + maxChatImagePayloadBytes = 2_500_000 +) + +var allowedChatImageMediaTypes = []string{"image/png", "image/jpeg", "image/gif", "image/webp"} + func SendAgentChatMessage(ctx context.Context, svc AgentsService, orgID, userID string, req *pb.SendAgentChatMessageRequest) (*pb.SendAgentChatMessageResponse, error) { org, user, err := parseOrgUser(orgID, userID) if err != nil { @@ -22,11 +36,15 @@ func SendAgentChatMessage(ctx context.Context, svc AgentsService, orgID, userID if err != nil { return nil, status.Error(codes.InvalidArgument, "invalid chat id") } - if req.Content == "" { - return nil, status.Error(codes.InvalidArgument, "content is required") + images, err := parseChatImages(req.Images) + if err != nil { + return nil, err + } + if req.Content == "" && len(images) == 0 { + return nil, status.Error(codes.InvalidArgument, "content or an image is required") } - persisted, err := svc.SendMessage(ctx, org, user, chatID, req.Content, agentModeFromProto(req.Mode)) + persisted, err := svc.SendMessage(ctx, org, user, chatID, req.Content, images, agentModeFromProto(req.Mode)) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Error(codes.NotFound, "agent chat not found") @@ -39,3 +57,33 @@ func SendAgentChatMessage(ctx context.Context, svc AgentsService, orgID, userID } return &pb.SendAgentChatMessageResponse{Message: serializeMessage(persisted)}, nil } + +func parseChatImages(images []*pb.AgentChatImage) ([]agentservice.MessageImage, error) { + if len(images) == 0 { + return nil, nil + } + if len(images) > maxChatImages { + return nil, status.Errorf(codes.InvalidArgument, "at most %d images are allowed per message", maxChatImages) + } + + out := make([]agentservice.MessageImage, 0, len(images)) + total := 0 + for _, image := range images { + if !slices.Contains(allowedChatImageMediaTypes, image.MediaType) { + return nil, status.Errorf(codes.InvalidArgument, "unsupported image media type: %q", image.MediaType) + } + decoded, err := base64.StdEncoding.DecodeString(image.Data) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "image data must be valid base64") + } + if len(decoded) == 0 { + return nil, status.Error(codes.InvalidArgument, "image data is empty") + } + total += len(decoded) + if total > maxChatImagePayloadBytes { + return nil, status.Errorf(codes.InvalidArgument, "images exceed the %d byte limit per message", maxChatImagePayloadBytes) + } + out = append(out, agentservice.MessageImage{MediaType: image.MediaType, Data: image.Data}) + } + return out, nil +} diff --git a/pkg/grpc/actions/agents/send_agent_chat_message_test.go b/pkg/grpc/actions/agents/send_agent_chat_message_test.go index 2d128ed49d..efd7ba0cd1 100644 --- a/pkg/grpc/actions/agents/send_agent_chat_message_test.go +++ b/pkg/grpc/actions/agents/send_agent_chat_message_test.go @@ -1,7 +1,9 @@ package agents_test import ( + "bytes" "context" + "encoding/base64" "testing" "github.com/google/uuid" @@ -36,7 +38,7 @@ func TestSendAgentChatMessage_ProjectsSuccess(t *testing.T) { persistedID := uuid.New() svc := &stubService{ - sendMessage: func(_ context.Context, _, _, sid uuid.UUID, content string, mode string) (*models.AgentSessionMessage, error) { + sendMessage: func(_ context.Context, _, _, sid uuid.UUID, content string, _ []agentservice.MessageImage, mode string) (*models.AgentSessionMessage, error) { assert.Equal(t, chatID, sid) assert.Equal(t, "operator", mode) return &models.AgentSessionMessage{ @@ -60,7 +62,7 @@ func TestSendAgentChatMessage_TranslatesNotFound(t *testing.T) { r := support.Setup(t) defer r.Close() svc := &stubService{ - sendMessage: func(context.Context, uuid.UUID, uuid.UUID, uuid.UUID, string, string) (*models.AgentSessionMessage, error) { + sendMessage: func(context.Context, uuid.UUID, uuid.UUID, uuid.UUID, string, []agentservice.MessageImage, string) (*models.AgentSessionMessage, error) { return nil, gorm.ErrRecordNotFound }, } @@ -76,7 +78,7 @@ func TestSendAgentChatMessage_TranslatesBusySession(t *testing.T) { r := support.Setup(t) defer r.Close() svc := &stubService{ - sendMessage: func(context.Context, uuid.UUID, uuid.UUID, uuid.UUID, string, string) (*models.AgentSessionMessage, error) { + sendMessage: func(context.Context, uuid.UUID, uuid.UUID, uuid.UUID, string, []agentservice.MessageImage, string) (*models.AgentSessionMessage, error) { return nil, agentservice.ErrSessionBusy }, } @@ -93,7 +95,7 @@ func TestSendAgentChatMessage_MapsBuilderMode(t *testing.T) { defer r.Close() svc := &stubService{ - sendMessage: func(_ context.Context, _, _, _ uuid.UUID, _ string, mode string) (*models.AgentSessionMessage, error) { + sendMessage: func(_ context.Context, _, _, _ uuid.UUID, _ string, _ []agentservice.MessageImage, mode string) (*models.AgentSessionMessage, error) { assert.Equal(t, "builder", mode) return &models.AgentSessionMessage{ ID: uuid.New(), @@ -111,3 +113,78 @@ func TestSendAgentChatMessage_MapsBuilderMode(t *testing.T) { }) require.NoError(t, err) } + +func TestSendAgentChatMessage_ForwardsAndSerializesImages(t *testing.T) { + r := support.Setup(t) + defer r.Close() + + var forwarded []agentservice.MessageImage + svc := &stubService{ + sendMessage: func(_ context.Context, _, _, _ uuid.UUID, content string, images []agentservice.MessageImage, _ string) (*models.AgentSessionMessage, error) { + forwarded = images + return &models.AgentSessionMessage{ + ID: uuid.New(), + Role: models.AgentMessageRoleUser, + Content: content, + Images: []models.AgentSessionImage{{MediaType: "image/png", Data: "aGVsbG8="}}, + CreatedAt: now(), + }, nil + }, + } + + resp, err := actionsagents.SendAgentChatMessage(context.Background(), svc, r.Organization.ID.String(), r.User.String(), &pb.SendAgentChatMessageRequest{ + ChatId: uuid.NewString(), + Content: "", + Images: []*pb.AgentChatImage{{MediaType: "image/png", Data: "aGVsbG8="}}, + }) + require.NoError(t, err) + require.Len(t, forwarded, 1) + assert.Equal(t, "image/png", forwarded[0].MediaType) + require.Len(t, resp.Message.Images, 1) + assert.Equal(t, "image/png", resp.Message.Images[0].MediaType) + // Image bytes are served out-of-band, never embedded in the response. + assert.Empty(t, resp.Message.Images[0].Data) +} + +func TestSendAgentChatMessage_RejectsInvalidImages(t *testing.T) { + r := support.Setup(t) + defer r.Close() + svc := &stubService{} + + cases := []struct { + name string + image *pb.AgentChatImage + }{ + {"unsupported media type", &pb.AgentChatImage{MediaType: "image/tiff", Data: "aGVsbG8="}}, + {"invalid base64", &pb.AgentChatImage{MediaType: "image/png", Data: "not base64!!"}}, + {"empty data", &pb.AgentChatImage{MediaType: "image/png", Data: ""}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, err := actionsagents.SendAgentChatMessage(context.Background(), svc, r.Organization.ID.String(), r.User.String(), &pb.SendAgentChatMessageRequest{ + ChatId: uuid.NewString(), + Images: []*pb.AgentChatImage{tc.image}, + }) + require.Error(t, err) + assert.Equal(t, codes.InvalidArgument, status.Code(err)) + }) + } +} + +func TestSendAgentChatMessage_RejectsImagesOverPayloadLimit(t *testing.T) { + r := support.Setup(t) + defer r.Close() + svc := &stubService{} + + // Two 2 MiB images decode to 4 MiB combined, over the per-message budget. + big := base64.StdEncoding.EncodeToString(bytes.Repeat([]byte{0}, 2*1024*1024)) + _, err := actionsagents.SendAgentChatMessage(context.Background(), svc, r.Organization.ID.String(), r.User.String(), &pb.SendAgentChatMessageRequest{ + ChatId: uuid.NewString(), + Images: []*pb.AgentChatImage{ + {MediaType: "image/png", Data: big}, + {MediaType: "image/png", Data: big}, + }, + }) + require.Error(t, err) + assert.Equal(t, codes.InvalidArgument, status.Code(err)) +} diff --git a/pkg/models/agent_session_message.go b/pkg/models/agent_session_message.go index 0117bc3a20..fd27af20f4 100644 --- a/pkg/models/agent_session_message.go +++ b/pkg/models/agent_session_message.go @@ -5,6 +5,7 @@ import ( "github.com/google/uuid" "github.com/superplanehq/superplane/pkg/database" + "gorm.io/datatypes" "gorm.io/gorm" ) @@ -19,6 +20,12 @@ const ( AgentToolStatusFailed = "failed" ) +// AgentSessionImage is a base64-encoded image attached to a user message. +type AgentSessionImage struct { + MediaType string `json:"media_type"` + Data string `json:"data"` +} + type AgentSessionMessage struct { ID uuid.UUID `gorm:"primaryKey;default:uuid_generate_v4()"` SessionID uuid.UUID @@ -28,6 +35,7 @@ type AgentSessionMessage struct { ToolCallID string ToolName string ToolStatus string + Images datatypes.JSONSlice[AgentSessionImage] CreatedAt *time.Time } @@ -44,6 +52,9 @@ func AppendAgentSessionMessageInTransaction(tx *gorm.DB, msg *AgentSessionMessag now := time.Now() msg.CreatedAt = &now } + if msg.Images == nil { + msg.Images = datatypes.JSONSlice[AgentSessionImage]{} + } if msg.ProviderEventID == "" { return tx.Create(msg).Error @@ -88,6 +99,18 @@ func AppendAgentSessionMessage(msg *AgentSessionMessage) error { return AppendAgentSessionMessageInTransaction(database.Conn(), msg) } +func FindAgentSessionMessageInTransaction(tx *gorm.DB, id uuid.UUID) (*AgentSessionMessage, error) { + var message AgentSessionMessage + if err := tx.Where("id = ?", id).First(&message).Error; err != nil { + return nil, err + } + return &message, nil +} + +func FindAgentSessionMessage(id uuid.UUID) (*AgentSessionMessage, error) { + return FindAgentSessionMessageInTransaction(database.Conn(), id) +} + // ListAgentSessionMessagesPage returns up to `limit` messages strictly older // than `before` (or the most recent `limit` when `before` is nil), in // chronological order (oldest-first). Used for tail-paginated chat scroll. diff --git a/pkg/public/agent_chat_message_image.go b/pkg/public/agent_chat_message_image.go new file mode 100644 index 0000000000..ab84343c79 --- /dev/null +++ b/pkg/public/agent_chat_message_image.go @@ -0,0 +1,76 @@ +package public + +import ( + "encoding/base64" + "net/http" + "strconv" + + "github.com/google/uuid" + "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" + "github.com/superplanehq/superplane/pkg/models" + "github.com/superplanehq/superplane/pkg/public/middleware" +) + +// handleAgentChatMessageImage streams a single image attached to an agent chat +// message. Image bytes are served out-of-band (rather than embedded as base64 +// in the message list response) so chat history stays small enough to fit under +// the gRPC/HTTP response size limits. +func (s *Server) handleAgentChatMessageImage(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + + sessionID, err := uuid.Parse(vars["chatId"]) + if err != nil { + http.Error(w, "invalid chat id", http.StatusBadRequest) + return + } + messageID, err := uuid.Parse(vars["messageId"]) + if err != nil { + http.Error(w, "invalid message id", http.StatusBadRequest) + return + } + index, err := strconv.Atoi(vars["index"]) + if err != nil || index < 0 { + http.Error(w, "invalid image index", http.StatusBadRequest) + return + } + + user, ok := middleware.GetUserFromContext(r.Context()) + if !ok { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Agent sessions are private to their creator, so this both authorizes the + // request and scopes the lookup to the caller. + if _, err := models.FindAgentSessionForUser(user.OrganizationID, user.ID, sessionID); err != nil { + http.Error(w, "agent chat not found", http.StatusNotFound) + return + } + + message, err := models.FindAgentSessionMessage(messageID) + if err != nil || message.SessionID != sessionID { + http.Error(w, "message not found", http.StatusNotFound) + return + } + + if index >= len(message.Images) { + http.Error(w, "image not found", http.StatusNotFound) + return + } + + image := message.Images[index] + data, err := base64.StdEncoding.DecodeString(image.Data) + if err != nil { + log.Errorf("failed to decode agent chat message image %s[%d]: %v", messageID, index, err) + http.Error(w, "invalid image data", http.StatusInternalServerError) + return + } + + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Content-Type", image.MediaType) + w.Header().Set("Cache-Control", "private, max-age=86400, immutable") + if _, err := w.Write(data); err != nil { + log.Errorf("failed to write agent chat message image %s[%d]: %v", messageID, index, err) + } +} diff --git a/pkg/public/agent_chat_message_image_test.go b/pkg/public/agent_chat_message_image_test.go new file mode 100644 index 0000000000..f194817bd9 --- /dev/null +++ b/pkg/public/agent_chat_message_image_test.go @@ -0,0 +1,89 @@ +package public + +import ( + "encoding/base64" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/superplanehq/superplane/pkg/authentication" + "github.com/superplanehq/superplane/pkg/database" + "github.com/superplanehq/superplane/pkg/jwt" + "github.com/superplanehq/superplane/pkg/models" + "github.com/superplanehq/superplane/test/support" + "gorm.io/datatypes" +) + +func TestAgentChatMessageImage(t *testing.T) { + r := support.Setup(t) + defer r.Close() + + signer := jwt.NewSigner("test") + server, err := NewServer( + r.Encryptor, r.Registry, signer, support.NewOIDCProvider(), r.GitProvider, + "", "http://localhost", "http://localhost", "test", "/app/templates", r.AuthService, nil, false, + ) + require.NoError(t, err) + require.NoError(t, server.RegisterGRPCGateway("localhost:50051")) + + token, err := authentication.GenerateAccountToken(signer, r.Account.ID.String(), time.Now(), time.Hour) + require.NoError(t, err) + + canvas, _ := support.CreateCanvas(t, r.Organization.ID, r.User, nil, nil) + session := &models.AgentSession{ + OrganizationID: r.Organization.ID, + UserID: r.User, + CanvasID: canvas.ID, + Provider: "anthropic", + ProviderSessionID: "provider-session", + Status: models.AgentSessionStatusIdle, + } + require.NoError(t, models.CreateAgentSessionInTransaction(database.Conn(), session)) + + raw := []byte("pretend image bytes") + message := &models.AgentSessionMessage{ + SessionID: session.ID, + Role: models.AgentMessageRoleUser, + Content: "look at this", + Images: datatypes.JSONSlice[models.AgentSessionImage]{ + {MediaType: "image/png", Data: base64.StdEncoding.EncodeToString(raw)}, + }, + } + require.NoError(t, models.AppendAgentSessionMessage(message)) + + base := "/api/v1/agents/chats/" + session.ID.String() + "/messages/" + message.ID.String() + "/images/" + org := "?organization_id=" + r.Organization.ID.String() + + get := func(path string, withAuth bool) *httptest.ResponseRecorder { + req := httptest.NewRequest(http.MethodGet, path, nil) + if withAuth { + req.AddCookie(&http.Cookie{Name: "account_token", Value: token}) + } + rec := httptest.NewRecorder() + server.Router.ServeHTTP(rec, req) + return rec + } + + t.Run("serves the decoded image bytes", func(t *testing.T) { + res := get(base+"0"+org, true) + + require.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, "image/png", res.Header().Get("Content-Type")) + assert.Equal(t, raw, res.Body.Bytes()) + }) + + t.Run("out-of-range index returns 404", func(t *testing.T) { + res := get(base+"9"+org, true) + + assert.Equal(t, http.StatusNotFound, res.Code) + }) + + t.Run("unauthenticated request is rejected", func(t *testing.T) { + res := get(base+"0"+org, false) + + assert.NotEqual(t, http.StatusOK, res.Code) + }) +} diff --git a/pkg/public/server.go b/pkg/public/server.go index 624c8bfcc9..1935a40d96 100644 --- a/pkg/public/server.go +++ b/pkg/public/server.go @@ -362,6 +362,13 @@ func (s *Server) RegisterGRPCGateway(grpcServerAddr string) error { orgAuthMiddleware(http.HandlerFunc(s.handleRepositoryFileDownload)), ).Methods(http.MethodGet) + // Registered before the /api/v1/agents gateway catch-all so the exact route + // matches first and serves image bytes out-of-band. + s.Router.Handle( + "/api/v1/agents/chats/{chatId}/messages/{messageId}/images/{index}", + orgAuthMiddleware(http.HandlerFunc(s.handleAgentChatMessageImage)), + ).Methods(http.MethodGet) + protectedGRPCHandler := orgAuthMiddleware(s.grpcGatewayHandler(grpcGatewayMux)) accountAuthMiddleware := middleware.AccountAuthMiddleware(s.jwt) diff --git a/protos/agents.proto b/protos/agents.proto index 537ad0e3fa..1c7d6906bb 100644 --- a/protos/agents.proto +++ b/protos/agents.proto @@ -114,6 +114,17 @@ message SendAgentChatMessageRequest { string content = 2; // Controls whether the agent should build, answer questions, or plan. AgentMode mode = 3; + // Images attached to the message (e.g. a pasted screenshot). Sent to the + // agent alongside the text so it can reason about them. + repeated AgentChatImage images = 4; +} + +// AgentChatImage is a base64-encoded image attached to a chat message. +message AgentChatImage { + // IANA media type, e.g. "image/png" or "image/jpeg". + string media_type = 1; + // Base64-encoded image bytes, without a data URI prefix. + string data = 2; } message SendAgentChatMessageResponse { @@ -152,4 +163,6 @@ message AgentChatMessage { string tool_name = 5; string tool_status = 6; google.protobuf.Timestamp created_at = 7; + // Images attached to the message by the user. + repeated AgentChatImage images = 8; } diff --git a/web_src/src/components/AgentSidebar/ChatComposer.tsx b/web_src/src/components/AgentSidebar/ChatComposer.tsx index 09166e7621..a1c0b16ea7 100644 --- a/web_src/src/components/AgentSidebar/ChatComposer.tsx +++ b/web_src/src/components/AgentSidebar/ChatComposer.tsx @@ -5,11 +5,14 @@ import { useMentions } from "./useMentions"; import { useMentionCandidates } from "./useMentionCandidates"; import { MentionDropdown } from "./MentionDropdown"; import { MentionTextarea } from "./MentionTextarea"; +import { ImageAttachmentPreviews } from "./ImageAttachmentPreviews"; +import { MAX_IMAGE_ATTACHMENTS, isSupportedImageFile, useImageAttachments } from "./useImageAttachments"; +import type { AgentOutgoingImage } from "@/components/CanvasToolSidebar/types"; import type { SuperplaneComponentsNode } from "@/api-client"; import type { CanvasesCanvasRun } from "@/api-client"; type ChatComposerProps = { - onSend: (content: string) => Promise; + onSend: (content: string, images: AgentOutgoingImage[]) => Promise; onStop: () => void; sending: boolean; sendPending: boolean; @@ -35,45 +38,104 @@ export function ChatComposer({ nodes, runs, }: ChatComposerProps) { + const c = useComposerController({ onSend, sendPending, nodes, runs }); + + return ( +
+
+ + + +
+ {c.showDropdown ? ( + + ) : null} +
+ ); +} + +type ComposerControllerArgs = { + onSend: (content: string, images: AgentOutgoingImage[]) => Promise; + sendPending: boolean; + nodes?: SuperplaneComponentsNode[]; + runs?: CanvasesCanvasRun[]; +}; + +function useComposerController({ onSend, sendPending, nodes, runs }: ComposerControllerArgs) { const textareaRef = useRef(null); const containerRef = useRef(null); const backdropRef = useRef(null); const mentionKeyboardRef = useRef<((e: React.KeyboardEvent) => boolean) | null>(null); - const { - value, - setValue, - showDropdown, - filter, - setCursorPos, - insertMention, - getMarkdown, - mentions, - isEmpty, - clear, - snapshot, - restore, - dismiss, - } = useMentions(); + const mentionsApi = useMentions(); + const { value, setValue, showDropdown, filter, setCursorPos, getMarkdown, mentions, isEmpty } = mentionsApi; + const { images, addFiles, removeImage, clear: clearImages } = useImageAttachments(); const candidates = useMentionCandidates(nodes, runs, filter, showDropdown); - const canSend = !isEmpty && !sendPending; + const hasImages = images.length > 0; + const canSend = (!isEmpty || hasImages) && !sendPending; + const canAttach = images.length < MAX_IMAGE_ATTACHMENTS; const handleSend = useCallback(async () => { - if (isEmpty) return; const content = getMarkdown().trim(); - if (!content) return; - snapshot(); - clear(); + if (!content && !hasImages) return; + const outgoingImages = images.map(({ mediaType, data }) => ({ mediaType, data })); + mentionsApi.snapshot(); + mentionsApi.clear(); try { - await onSend(content); + await onSend(content, outgoingImages); + clearImages(); } catch { - restore(); + mentionsApi.restore(); } - }, [isEmpty, getMarkdown, clear, onSend, snapshot, restore]); + }, [hasImages, images, getMarkdown, clearImages, onSend, mentionsApi]); + + const handlePaste = useCallback( + (e: React.ClipboardEvent) => { + const files = imageFilesFromClipboard(e); + if (files.length === 0) return; + // Only suppress the default paste when the clipboard is image-only, so a + // mixed text-and-image paste still inserts its text. + if (e.clipboardData.getData("text/plain").length === 0) e.preventDefault(); + void addFiles(files); + }, + [addFiles], + ); const handleMentionSelect = useCallback( (item: { type: "node" | "run"; id: string; label: string; meta?: string }) => { - const pos = insertMention(item); + const pos = mentionsApi.insertMention(item); requestAnimationFrame(() => { const ta = textareaRef.current; if (ta) { @@ -82,13 +144,13 @@ export function ChatComposer({ } }); }, - [insertMention], + [mentionsApi], ); const handleDismiss = useCallback(() => { - dismiss(); + mentionsApi.dismiss(); textareaRef.current?.focus(); - }, [dismiss]); + }, [mentionsApi]); const handleKeyDown = useCallback( (e: React.KeyboardEvent) => { @@ -106,46 +168,36 @@ export function ChatComposer({ void handleSend(); }); - return ( -
-
- - -
- {showDropdown ? ( - - ) : null} -
- ); + return { + textareaRef, + containerRef, + backdropRef, + mentionKeyboardRef, + value, + setValue, + setCursorPos, + mentions, + showDropdown, + candidates, + images, + addFiles, + removeImage, + canSend, + canAttach, + handleSend, + handlePaste, + handleMentionSelect, + handleDismiss, + handleKeyDown, + handleToolbarSend, + }; +} + +function imageFilesFromClipboard(e: React.ClipboardEvent): File[] { + return Array.from(e.clipboardData.items) + .filter((item) => item.kind === "file") + .map((item) => item.getAsFile()) + .filter((file): file is File => file !== null && isSupportedImageFile(file)); } function useStableCallback(callback: () => void): () => void { diff --git a/web_src/src/components/AgentSidebar/ComposerToolbar.tsx b/web_src/src/components/AgentSidebar/ComposerToolbar.tsx index 5c8b57fa5a..02d796ab85 100644 --- a/web_src/src/components/AgentSidebar/ComposerToolbar.tsx +++ b/web_src/src/components/AgentSidebar/ComposerToolbar.tsx @@ -1,7 +1,8 @@ -import { memo } from "react"; -import { ArrowUp, Loader2, Square } from "lucide-react"; +import { memo, useRef } from "react"; +import { ArrowUp, ImagePlus, Loader2, Square } from "lucide-react"; import { Button } from "@/components/ui/button"; import type { AgentMode } from "./agentMode"; +import { ALLOWED_IMAGE_TYPES } from "./useImageAttachments"; import { ModeToggle } from "./ModeToggle"; interface ComposerToolbarProps { @@ -12,8 +13,52 @@ interface ComposerToolbarProps { stopping?: boolean; statusLabel: string; canSend: boolean; + canAttach: boolean; onStop: () => void; onSend: () => void; + onAddFiles: (files: FileList | File[]) => void; +} + +function AttachImageButton({ + canAttach, + onAddFiles, +}: { + canAttach: boolean; + onAddFiles: (files: FileList | File[]) => void; +}) { + const fileInputRef = useRef(null); + + return ( + <> + { + if (event.target.files && event.target.files.length > 0) { + onAddFiles(event.target.files); + } + event.target.value = ""; + }} + /> + + + ); } export const ComposerToolbar = memo(function ComposerToolbar({ @@ -24,12 +69,17 @@ export const ComposerToolbar = memo(function ComposerToolbar({ stopping, statusLabel, canSend, + canAttach, onStop, onSend, + onAddFiles, }: ComposerToolbarProps) { return (
- +
+ + +
{statusLabel} {sending && ( diff --git a/web_src/src/components/AgentSidebar/ImageAttachmentPreviews.tsx b/web_src/src/components/AgentSidebar/ImageAttachmentPreviews.tsx new file mode 100644 index 0000000000..b53d1a3160 --- /dev/null +++ b/web_src/src/components/AgentSidebar/ImageAttachmentPreviews.tsx @@ -0,0 +1,32 @@ +import { X } from "lucide-react"; +import type { ComposerImage } from "./useImageAttachments"; + +interface ImageAttachmentPreviewsProps { + images: ComposerImage[]; + onRemove: (id: string) => void; +} + +export function ImageAttachmentPreviews({ images, onRemove }: ImageAttachmentPreviewsProps) { + if (images.length === 0) return null; + + return ( +
+ {images.map((image) => ( +
+ {image.name} + +
+ ))} +
+ ); +} diff --git a/web_src/src/components/AgentSidebar/MentionTextarea.tsx b/web_src/src/components/AgentSidebar/MentionTextarea.tsx index df70d10929..70a2ef481e 100644 --- a/web_src/src/components/AgentSidebar/MentionTextarea.tsx +++ b/web_src/src/components/AgentSidebar/MentionTextarea.tsx @@ -13,6 +13,7 @@ interface MentionTextareaProps { setValue: (v: string) => void; setCursorPos: (pos: number) => void; onKeyDown: (e: React.KeyboardEvent) => void; + onPaste?: (e: React.ClipboardEvent) => void; placeholder?: string; textareaRef: React.RefObject; backdropRef: React.RefObject; @@ -24,6 +25,7 @@ export function MentionTextarea({ setValue, setCursorPos, onKeyDown, + onPaste, placeholder, textareaRef, backdropRef, @@ -110,6 +112,7 @@ export function MentionTextarea({ onKeyUp={handleSelect} onClick={handleSelect} onScroll={handleScroll} + onPaste={onPaste} rows={1} placeholder={placeholder} data-testid="agent-input" diff --git a/web_src/src/components/AgentSidebar/useImageAttachments.ts b/web_src/src/components/AgentSidebar/useImageAttachments.ts new file mode 100644 index 0000000000..8c8275d31a --- /dev/null +++ b/web_src/src/components/AgentSidebar/useImageAttachments.ts @@ -0,0 +1,116 @@ +import { useCallback, useRef, useState } from "react"; +import { showErrorToast } from "@/lib/toast"; + +export const MAX_IMAGE_ATTACHMENTS = 8; +export const ALLOWED_IMAGE_TYPES = ["image/png", "image/jpeg", "image/gif", "image/webp"]; + +// Caps the combined raw image bytes per message. Images are sent as base64 +// (~4/3 larger) alongside the message text, so this stays well under the gRPC +// server's 4 MiB receive limit and mirrors maxChatImagePayloadBytes in +// pkg/grpc/actions/agents/send_agent_chat_message.go. Keeping it at or below the +// backend cap means oversized attachments are rejected client-side with a clear +// error instead of failing the request with an HTTP 429. +export const MAX_TOTAL_IMAGE_BYTES = 2_500_000; +// A single image may use the entire per-message budget. +export const MAX_IMAGE_BYTES = MAX_TOTAL_IMAGE_BYTES; + +export type ComposerImage = { + id: string; + name: string; + mediaType: string; + // Raw (decoded) byte size, used to enforce the per-message payload budget. + bytes: number; + // Full `data:;base64,` URL, used for inline previews. + dataUrl: string; + // Base64 payload only, without the data URI prefix, sent to the API. + data: string; +}; + +export type UseImageAttachmentsReturn = { + images: ComposerImage[]; + addFiles: (files: FileList | File[]) => Promise; + removeImage: (id: string) => void; + clear: () => void; +}; + +// isSupportedImageFile reports whether the composer handles a file at all. Size +// limits are enforced in addFiles (with user feedback), so callers can use this +// to decide whether to intercept a paste regardless of the file's size. +export function isSupportedImageFile(file: File): boolean { + return ALLOWED_IMAGE_TYPES.includes(file.type) && file.size > 0; +} + +export function useImageAttachments(): UseImageAttachmentsReturn { + const [images, setImages] = useState([]); + const imagesRef = useRef(images); + imagesRef.current = images; + + const addFiles = useCallback(async (files: FileList | File[]) => { + const candidates = Array.from(files).filter(isSupportedImageFile); + if (candidates.length === 0) return; + + const sized = candidates.filter((file) => file.size <= MAX_IMAGE_BYTES); + if (sized.length < candidates.length) { + showErrorToast(`Each image must be ${formatMegabytes(MAX_IMAGE_BYTES)} or smaller.`); + } + if (sized.length === 0) return; + + // A single unreadable file must not drop the rest of the batch. + const read = (await Promise.all(sized.map(readImage))).filter((image): image is ComposerImage => image !== null); + if (read.length === 0) return; + + const current = imagesRef.current; + const accepted: ComposerImage[] = []; + let count = current.length; + let total = current.reduce((sum, image) => sum + image.bytes, 0); + let rejected = false; + for (const image of read) { + if (count >= MAX_IMAGE_ATTACHMENTS || total + image.bytes > MAX_TOTAL_IMAGE_BYTES) { + rejected = true; + continue; + } + accepted.push(image); + count += 1; + total += image.bytes; + } + if (rejected) { + showErrorToast( + `Attachments are limited to ${MAX_IMAGE_ATTACHMENTS} images and ${formatMegabytes(MAX_TOTAL_IMAGE_BYTES)} per message.`, + ); + } + if (accepted.length > 0) { + setImages((previous) => [...previous, ...accepted]); + } + }, []); + + const removeImage = useCallback((id: string) => { + setImages((current) => current.filter((image) => image.id !== id)); + }, []); + + const clear = useCallback(() => setImages([]), []); + + return { images, addFiles, removeImage, clear }; +} + +function readImage(file: File): Promise { + return new Promise((resolve) => { + const reader = new FileReader(); + reader.onerror = () => resolve(null); + reader.onload = () => { + const dataUrl = String(reader.result); + resolve({ + id: crypto.randomUUID(), + name: file.name || "image", + mediaType: file.type, + bytes: file.size, + dataUrl, + data: dataUrl.slice(dataUrl.indexOf(",") + 1), + }); + }; + reader.readAsDataURL(file); + }); +} + +function formatMegabytes(bytes: number): string { + return `${(bytes / (1024 * 1024)).toFixed(1)} MB`; +} diff --git a/web_src/src/components/CanvasToolSidebar/AgentConversationTranscript.spec.tsx b/web_src/src/components/CanvasToolSidebar/AgentConversationTranscript.spec.tsx index 9fbd73b4fc..c94a6138d9 100644 --- a/web_src/src/components/CanvasToolSidebar/AgentConversationTranscript.spec.tsx +++ b/web_src/src/components/CanvasToolSidebar/AgentConversationTranscript.spec.tsx @@ -73,6 +73,29 @@ describe("ConversationTranscript command groups", () => { }); describe("ConversationTranscript user messages", () => { + it("renders attached images as linked thumbnails", () => { + const groups: MessageGroup[] = [ + { + type: "message", + message: { + id: "user-with-image", + role: "user", + content: "fix this", + toolName: "", + toolCallId: "", + toolStatus: "", + images: [{ mediaType: "image/png", url: "/api/v1/agents/chats/c-1/messages/user-with-image/images/0" }], + createdAt: null, + }, + }, + ]; + + render(); + + const image = screen.getByRole("img", { name: "attachment" }); + expect(image).toHaveAttribute("src", "/api/v1/agents/chats/c-1/messages/user-with-image/images/0"); + }); + it("keeps compact user messages sticky", () => { render(); diff --git a/web_src/src/components/CanvasToolSidebar/AgentConversationTranscript.tsx b/web_src/src/components/CanvasToolSidebar/AgentConversationTranscript.tsx index 99637dd946..d87120ad92 100644 --- a/web_src/src/components/CanvasToolSidebar/AgentConversationTranscript.tsx +++ b/web_src/src/components/CanvasToolSidebar/AgentConversationTranscript.tsx @@ -196,6 +196,7 @@ const MessageRow = memo(function MessageRow({ )} data-testid={isUser ? "agent-user-message" : "agent-assistant-message"} > + + {images.map((image, index) => ( + + attachment + + ))} +
+ ); +} + function shouldRenderMessage(message: AgentMessage): boolean { return message.role !== "system" && !(message.role === "user" && isSystemNotification(message.content)); } diff --git a/web_src/src/components/CanvasToolSidebar/AgentTabPanel.tsx b/web_src/src/components/CanvasToolSidebar/AgentTabPanel.tsx index 715a51308b..62b5b2fe21 100644 --- a/web_src/src/components/CanvasToolSidebar/AgentTabPanel.tsx +++ b/web_src/src/components/CanvasToolSidebar/AgentTabPanel.tsx @@ -34,7 +34,7 @@ import { useStoredOutcomeState, useThinkingIndicator, } from "./agentConversationState"; -import type { AgentMessage } from "./types"; +import type { AgentMessage, AgentOutgoingImage } from "./types"; import type { CanvasToolSidebarState } from "./useCanvasToolSidebarState"; import { groupMessages } from "./agentMessageGroups"; @@ -58,7 +58,7 @@ type DraftActionsBarProps = { }; type ConversationHandlers = { - handleSend: (content: string) => Promise; + handleSend: (content: string, images?: AgentOutgoingImage[]) => Promise; handleStop: () => void; handleQuickAction: (action: string) => Promise; handleStartBuilding: (rubric: { title: string; criteria: string[]; categories?: RubricCategory[] }) => Promise; @@ -312,11 +312,11 @@ function useConversationHandlers({ mutationsRef.current = { sendMutation, interruptMutation, outcomeMutation }; const handleSend = useCallback( - async (content: string) => { + async (content: string, images?: AgentOutgoingImage[]) => { const { sendMutation: send } = mutationsRef.current; - if (!content.trim() || send.isPending) return; + if ((!content.trim() && (images?.length ?? 0) === 0) || send.isPending) return; setError(null); - await send.mutateAsync({ chatId, content, mode: agentMode }).catch((error) => { + await send.mutateAsync({ chatId, content, mode: agentMode, images }).catch((error) => { setError(error instanceof Error ? error.message : "failed to send message"); throw error; }); @@ -407,7 +407,7 @@ function ComposerWithCanvasData({ }: { canvasId: string; organizationId: string; - onSend: (content: string) => Promise; + onSend: (content: string, images: AgentOutgoingImage[]) => Promise; onStop: () => void; sending: boolean; sendPending: boolean; diff --git a/web_src/src/components/CanvasToolSidebar/index.spec.tsx b/web_src/src/components/CanvasToolSidebar/index.spec.tsx index 0c7a056b3d..965ae646b0 100644 --- a/web_src/src/components/CanvasToolSidebar/index.spec.tsx +++ b/web_src/src/components/CanvasToolSidebar/index.spec.tsx @@ -111,7 +111,12 @@ describe("CanvasToolSidebar", () => { await user.type(screen.getByTestId("agent-input"), "retry"); await user.click(screen.getByTestId("agent-send-message-button")); - expect(sendMutation.mutateAsync).toHaveBeenCalledWith({ chatId: "chat-1", content: "retry", mode: "operator" }); + expect(sendMutation.mutateAsync).toHaveBeenCalledWith({ + chatId: "chat-1", + content: "retry", + mode: "operator", + images: [], + }); }); it("does not render when managed agents are disabled", () => { diff --git a/web_src/src/components/CanvasToolSidebar/types.spec.ts b/web_src/src/components/CanvasToolSidebar/types.spec.ts index 74032df62b..4f297e7d65 100644 --- a/web_src/src/components/CanvasToolSidebar/types.spec.ts +++ b/web_src/src/components/CanvasToolSidebar/types.spec.ts @@ -25,19 +25,23 @@ describe("fromApiChat", () => { describe("fromApiMessage", () => { it("returns null when id is missing", () => { - expect(fromApiMessage({ role: "user", content: "hi" })).toBeNull(); + expect(fromApiMessage({ role: "user", content: "hi" }, "chat-1", "org-1")).toBeNull(); }); it("preserves all populated fields", () => { - const msg = fromApiMessage({ - id: "msg-1", - role: "assistant", - content: "hello", - toolName: "search", - toolCallId: "call-1", - toolStatus: "started", - createdAt: "2026-05-13T00:00:00Z", - }); + const msg = fromApiMessage( + { + id: "msg-1", + role: "assistant", + content: "hello", + toolName: "search", + toolCallId: "call-1", + toolStatus: "started", + createdAt: "2026-05-13T00:00:00Z", + }, + "chat-1", + "org-1", + ); expect(msg).toEqual({ id: "msg-1", role: "assistant", @@ -45,7 +49,34 @@ describe("fromApiMessage", () => { toolName: "search", toolCallId: "call-1", toolStatus: "started", + images: [], createdAt: "2026-05-13T00:00:00Z", }); }); + + it("maps images to out-of-band URLs keyed by their original index", () => { + const msg = fromApiMessage( + { + id: "msg-2", + role: "user", + content: "look", + images: [{ mediaType: "image/png" }, {}, { mediaType: "image/jpeg" }], + }, + "chat-9", + "org-7", + ); + expect(msg?.images).toEqual([ + { mediaType: "image/png", url: "/api/v1/agents/chats/chat-9/messages/msg-2/images/0?organization_id=org-7" }, + { mediaType: "image/jpeg", url: "/api/v1/agents/chats/chat-9/messages/msg-2/images/2?organization_id=org-7" }, + ]); + }); + + it("omits the organization query when no org is provided", () => { + const msg = fromApiMessage( + { id: "msg-3", role: "user", content: "look", images: [{ mediaType: "image/png" }] }, + "chat-9", + undefined, + ); + expect(msg?.images?.[0].url).toBe("/api/v1/agents/chats/chat-9/messages/msg-3/images/0"); + }); }); diff --git a/web_src/src/components/CanvasToolSidebar/types.ts b/web_src/src/components/CanvasToolSidebar/types.ts index 973cb4885b..31b0f151fd 100644 --- a/web_src/src/components/CanvasToolSidebar/types.ts +++ b/web_src/src/components/CanvasToolSidebar/types.ts @@ -9,6 +9,19 @@ export type AgentChat = { updatedAt: string | null; }; +// Image attached to a stored message. Bytes are served out-of-band by the +// image endpoint, so the message carries a URL rather than inline base64. +export type AgentMessageImage = { + mediaType: string; + url: string; +}; + +// Image being composed/sent by the client, carrying the base64 payload. +export type AgentOutgoingImage = { + mediaType: string; + data: string; +}; + export type AgentMessage = { id: string; role: string; @@ -16,6 +29,7 @@ export type AgentMessage = { toolName: string; toolCallId: string; toolStatus: string; + images?: AgentMessageImage[]; createdAt: string | null; }; @@ -54,15 +68,38 @@ export function fromApiChat(input: AgentsAgentChatInfo | undefined): AgentChat | }; } -export function fromApiMessage(input: AgentsAgentChatMessage | undefined): AgentMessage | null { +export function fromApiMessage( + input: AgentsAgentChatMessage | undefined, + chatId: string, + organizationId: string | undefined, +): AgentMessage | null { if (!input || !input.id) return null; + const messageId = input.id; return { - id: input.id, + id: messageId, role: input.role ?? "", content: input.content ?? "", toolName: input.toolName ?? "", toolCallId: input.toolCallId ?? "", toolStatus: input.toolStatus ?? "", + images: (input.images ?? []) + // Index matches the image's position in the stored message, which the + // server endpoint uses to address it; map before filtering to preserve it. + .map((image, index) => ({ + mediaType: image.mediaType ?? "", + url: agentMessageImageUrl(chatId, messageId, index, organizationId), + })) + .filter((image) => Boolean(image.mediaType)), createdAt: input.createdAt ?? null, }; } + +function agentMessageImageUrl( + chatId: string, + messageId: string, + index: number, + organizationId: string | undefined, +): string { + const query = organizationId ? `?organization_id=${encodeURIComponent(organizationId)}` : ""; + return `/api/v1/agents/chats/${chatId}/messages/${messageId}/images/${index}${query}`; +} diff --git a/web_src/src/hooks/useAgentChats.ts b/web_src/src/hooks/useAgentChats.ts index 29017b579f..06dbdcbf1e 100644 --- a/web_src/src/hooks/useAgentChats.ts +++ b/web_src/src/hooks/useAgentChats.ts @@ -13,7 +13,13 @@ import { agentsSendAgentChatMessage, } from "@/api-client/sdk.gen"; import type { AgentMode } from "@/components/AgentSidebar/agentMode"; -import { fromApiChat, fromApiMessage, type AgentChat, type AgentMessage } from "@/components/CanvasToolSidebar/types"; +import { + fromApiChat, + fromApiMessage, + type AgentChat, + type AgentMessage, + type AgentOutgoingImage, +} from "@/components/CanvasToolSidebar/types"; import { withOrganizationHeader } from "@/lib/withOrganizationHeader"; export const agentChatKeys = { @@ -60,7 +66,9 @@ export function useAgentChatMessages(chatId: string | null, organizationId: stri query: { beforeId: pageParam || undefined, limit: PAGE_SIZE }, }), ); - const messages = (response.data?.messages ?? []).map(fromApiMessage).filter((m): m is AgentMessage => m !== null); + const messages = (response.data?.messages ?? []) + .map((message) => fromApiMessage(message, chatId ?? "", organizationId)) + .filter((m): m is AgentMessage => m !== null); return { messages, hasMore: Boolean(response.data?.hasMore) }; }, getNextPageParam: (lastPage) => { @@ -73,15 +81,29 @@ export function useAgentChatMessages(chatId: string | null, organizationId: stri export function useSendAgentChatMessage(organizationId: string | undefined, _canvasId: string | undefined) { const queryClient = useQueryClient(); return useMutation({ - mutationFn: async ({ chatId, content, mode }: { chatId: string; content: string; mode?: AgentMode }) => { + mutationFn: async ({ + chatId, + content, + mode, + images, + }: { + chatId: string; + content: string; + mode?: AgentMode; + images?: AgentOutgoingImage[]; + }) => { const response = await agentsSendAgentChatMessage( withOrganizationHeader({ organizationId, path: { chatId }, - body: { content, mode: mode ? agentModeToApiMode[mode] : undefined }, + body: { + content, + mode: mode ? agentModeToApiMode[mode] : undefined, + images: images && images.length > 0 ? images : undefined, + }, }), ); - return fromApiMessage(response.data?.message); + return fromApiMessage(response.data?.message, chatId, organizationId); }, onSuccess: (data, variables) => { if (data) upsertAgentMessageInCache(queryClient, variables.chatId, data);