diff --git a/backend/internal/handler/available_channel_handler.go b/backend/internal/handler/available_channel_handler.go index 8982b80defc..e15a742f818 100644 --- a/backend/internal/handler/available_channel_handler.go +++ b/backend/internal/handler/available_channel_handler.go @@ -53,12 +53,15 @@ func (h *AvailableChannelHandler) featureEnabled(c *gin.Context) bool { // 订阅视觉加深),并用 RateMultiplier 作为默认倍率;用户专属倍率前端走 // /groups/rates,和 API 密钥页面保持一致。 type userAvailableGroup struct { - ID int64 `json:"id"` - Name string `json:"name"` - Platform string `json:"platform"` - SubscriptionType string `json:"subscription_type"` - RateMultiplier float64 `json:"rate_multiplier"` - IsExclusive bool `json:"is_exclusive"` + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` + SubscriptionType string `json:"subscription_type"` + RateMultiplier float64 `json:"rate_multiplier"` + IsExclusive bool `json:"is_exclusive"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` } // userSupportedModelPricing 用户可见的定价字段白名单。 @@ -219,6 +222,9 @@ func filterUserVisibleGroups( SubscriptionType: g.SubscriptionType, RateMultiplier: g.RateMultiplier, IsExclusive: g.IsExclusive, + ImagePrice1K: g.ImagePrice1K, + ImagePrice2K: g.ImagePrice2K, + ImagePrice4K: g.ImagePrice4K, }) } return visible diff --git a/backend/internal/service/channel_available.go b/backend/internal/service/channel_available.go index d2d24659a1a..86fd31d171b 100644 --- a/backend/internal/service/channel_available.go +++ b/backend/internal/service/channel_available.go @@ -19,6 +19,9 @@ type AvailableGroupRef struct { SubscriptionType string RateMultiplier float64 IsExclusive bool + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 } // AvailableChannel 可用渠道视图:用于「可用渠道」页面展示渠道基础信息 + @@ -34,6 +37,10 @@ type AvailableChannel struct { SupportedModels []SupportedModel } +type channelAccountMappingModelLister interface { + ListAccountMappingModelsByGroupIDs(ctx context.Context, groupIDs []int64) (map[int64][]string, error) +} + // ListAvailable 返回所有渠道的可用视图:每个渠道附带关联分组信息与支持模型列表。 // // 支持模型通过 (*Channel).SupportedModels() 计算(mapping ∪ pricing 并联)。 @@ -56,6 +63,7 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, return nil, fmt.Errorf("list active groups: %w", err) } groupByID := make(map[int64]AvailableGroupRef, len(groups)) + groupModelsListByID := make(map[int64]GroupModelsListConfig, len(groups)) for i := range groups { g := groups[i] groupByID[g.ID] = AvailableGroupRef{ @@ -65,15 +73,32 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, SubscriptionType: g.SubscriptionType, RateMultiplier: g.RateMultiplier, IsExclusive: g.IsExclusive, + ImagePrice1K: g.ImagePrice1K, + ImagePrice2K: g.ImagePrice2K, + ImagePrice4K: g.ImagePrice4K, + } + groupModelsListByID[g.ID] = g.ModelsListConfig + } + + accountMappedModelsByGroupID := map[int64][]string{} + if lister, ok := s.repo.(channelAccountMappingModelLister); ok { + groupIDs := collectActiveGroupIDs(groupByID) + if len(groupIDs) > 0 { + accountMappedModelsByGroupID, err = lister.ListAccountMappingModelsByGroupIDs(ctx, groupIDs) + if err != nil { + return nil, fmt.Errorf("list account mapping models: %w", err) + } } } out := make([]AvailableChannel, 0, len(channels)) + channelLinkedGroupIDs := make(map[int64]struct{}, len(groupByID)) for i := range channels { ch := &channels[i] groups := make([]AvailableGroupRef, 0, len(ch.GroupIDs)) for _, gid := range ch.GroupIDs { if ref, ok := groupByID[gid]; ok { + channelLinkedGroupIDs[gid] = struct{}{} groups = append(groups, ref) } } @@ -82,6 +107,9 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, ch.normalizeBillingModelSource() supported := ch.SupportedModels() + if !ch.RestrictModels { + supported = appendAccountMappedSupportedModels(supported, groups, accountMappedModelsByGroupID) + } s.fillGlobalPricingFallback(supported) out = append(out, AvailableChannel{ @@ -95,6 +123,7 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, SupportedModels: supported, }) } + out = append(out, s.buildDirectGroupAvailableChannels(groupByID, groupModelsListByID, channelLinkedGroupIDs, accountMappedModelsByGroupID)...) sort.SliceStable(out, func(i, j int) bool { return strings.ToLower(out[i].Name) < strings.ToLower(out[j].Name) @@ -102,6 +131,106 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, return out, nil } +func collectActiveGroupIDs(groupByID map[int64]AvailableGroupRef) []int64 { + groupIDs := make([]int64, 0, len(groupByID)) + for id := range groupByID { + groupIDs = append(groupIDs, id) + } + sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] }) + return groupIDs +} + +func (s *ChannelService) buildDirectGroupAvailableChannels( + groupByID map[int64]AvailableGroupRef, + groupModelsListByID map[int64]GroupModelsListConfig, + channelLinkedGroupIDs map[int64]struct{}, + mappedByGroupID map[int64][]string, +) []AvailableChannel { + out := make([]AvailableChannel, 0) + groupIDs := collectActiveGroupIDs(groupByID) + for _, groupID := range groupIDs { + if _, linked := channelLinkedGroupIDs[groupID]; linked { + continue + } + group := groupByID[groupID] + models := directGroupSupportedModels(group, groupModelsListByID[groupID], mappedByGroupID[groupID]) + if len(models) == 0 { + continue + } + s.fillGlobalPricingFallback(models) + out = append(out, AvailableChannel{ + ID: -groupID, + Name: group.Name, + Description: "分组直连", + Status: StatusActive, + RestrictModels: groupModelsListByID[groupID].Enabled, + Groups: []AvailableGroupRef{group}, + SupportedModels: models, + }) + } + return out +} + +func directGroupSupportedModels(group AvailableGroupRef, cfg GroupModelsListConfig, mappedModels []string) []SupportedModel { + modelNames := mappedModels + if cfg.Enabled && len(cfg.Models) > 0 { + modelNames = cfg.Models + } + seen := make(map[string]struct{}, len(modelNames)) + out := make([]SupportedModel, 0, len(modelNames)) + for _, model := range modelNames { + name := strings.TrimSpace(model) + if name == "" { + continue + } + key := strings.ToLower(name) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, SupportedModel{Name: name, Platform: group.Platform}) + } + sort.SliceStable(out, func(i, j int) bool { return strings.ToLower(out[i].Name) < strings.ToLower(out[j].Name) }) + return out +} + +func appendAccountMappedSupportedModels(supported []SupportedModel, groups []AvailableGroupRef, mappedByGroupID map[int64][]string) []SupportedModel { + if len(groups) == 0 || len(mappedByGroupID) == 0 { + return supported + } + type modelKey struct { + platform string + model string + } + seen := make(map[modelKey]struct{}, len(supported)) + for _, model := range supported { + seen[modelKey{platform: model.Platform, model: strings.ToLower(model.Name)}] = struct{}{} + } + out := append([]SupportedModel(nil), supported...) + for _, group := range groups { + models := mappedByGroupID[group.ID] + for _, model := range models { + name := strings.TrimSpace(model) + if name == "" { + continue + } + key := modelKey{platform: group.Platform, model: strings.ToLower(name)} + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, SupportedModel{Name: name, Platform: group.Platform}) + } + } + sort.SliceStable(out, func(i, j int) bool { + if out[i].Platform != out[j].Platform { + return out[i].Platform < out[j].Platform + } + return strings.ToLower(out[i].Name) < strings.ToLower(out[j].Name) + }) + return out +} + // fillGlobalPricingFallback 对未命中渠道定价的支持模型,从全局 LiteLLM 数据合成一份 // 展示用定价。仅用于「可用渠道」展示,不影响真实计费链路。 // @@ -119,6 +248,10 @@ func (s *ChannelService) fillGlobalPricingFallback(models []SupportedModel) { continue } lp := s.pricingService.GetModelPricing(models[i].Name) + if lp == nil && isOpenAIImageGenerationModel(models[i].Name) { + models[i].Pricing = synthesizeDefaultImagePricing(models[i].Pricing) + continue + } if lp == nil { continue } @@ -171,12 +304,17 @@ func synthesizePricingFromLiteLLM(lp *LiteLLMModelPricing, existing *ChannelMode } if mode == BillingModeImage || mode == BillingModePerRequest { + perRequestPrice := nonZeroPtr(lp.OutputCostPerImage) + if mode == BillingModeImage && perRequestPrice == nil { + perRequestPrice = availableFloat64Ptr(defaultImagePrice1K) + } return &ChannelModelPricing{ BillingMode: mode, - PerRequestPrice: nonZeroPtr(lp.OutputCostPerImage), + PerRequestPrice: perRequestPrice, ImageOutputPrice: nonZeroPtr(lp.OutputCostPerImageToken), InputPrice: nonZeroPtr(lp.InputCostPerToken), OutputPrice: nonZeroPtr(lp.OutputCostPerToken), + Intervals: synthesizeImagePriceTiers(mode, perRequestPrice), } } return &ChannelModelPricing{ @@ -189,9 +327,42 @@ func synthesizePricingFromLiteLLM(lp *LiteLLMModelPricing, existing *ChannelMode } } +const defaultImagePrice1K = 0.134 + +func synthesizeDefaultImagePricing(existing *ChannelModelPricing) *ChannelModelPricing { + mode := BillingModeImage + if existing != nil && existing.BillingMode != "" { + mode = existing.BillingMode + } + base := availableFloat64Ptr(defaultImagePrice1K) + return &ChannelModelPricing{ + BillingMode: mode, + PerRequestPrice: base, + Intervals: synthesizeImagePriceTiers(mode, base), + } +} + func nonZeroPtr(v float64) *float64 { if v == 0 { return nil } return &v } + +func availableFloat64Ptr(v float64) *float64 { + return &v +} + +func synthesizeImagePriceTiers(mode BillingMode, base *float64) []PricingInterval { + if mode != BillingModeImage || base == nil { + return nil + } + price1K := *base + price2K := price1K * 1.5 + price4K := price1K * 2 + return []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: &price1K, SortOrder: 1}, + {TierLabel: "2K", PerRequestPrice: &price2K, SortOrder: 2}, + {TierLabel: "4K", PerRequestPrice: &price4K, SortOrder: 3}, + } +} diff --git a/backend/internal/service/channel_available_test.go b/backend/internal/service/channel_available_test.go index d59e587ecd5..8f058382544 100644 --- a/backend/internal/service/channel_available_test.go +++ b/backend/internal/service/channel_available_test.go @@ -176,6 +176,80 @@ func TestListAvailable_DefaultsEmptyBillingModelSource(t *testing.T) { require.Equal(t, BillingModelSourceUpstream, byName["explicit"]) } +func TestListAvailable_AppendsAccountMappedModelsWhenChannelIsUnrestricted(t *testing.T) { + channels := []Channel{{ + ID: 1, + Name: "open-channel", + Status: StatusActive, + RestrictModels: false, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4-6"}, + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(1e-6), + }}, + }} + repo := &mockChannelRepository{ + listAllFn: func(ctx context.Context) ([]Channel, error) { return channels, nil }, + listAccountMappingModelsByGroupIDsFn: func(ctx context.Context, groupIDs []int64) (map[int64][]string, error) { + require.ElementsMatch(t, []int64{10}, groupIDs) + return map[int64][]string{ + 10: {"claude-opus-4-5", "claude-sonnet-4-6"}, + }, nil + }, + } + groupRepo := &stubGroupRepoForAvailable{ + activeGroups: []Group{{ID: 10, Name: "anthropic-go", Platform: "anthropic"}}, + } + svc := NewChannelService(repo, groupRepo, nil, nil) + + out, err := svc.ListAvailable(context.Background()) + require.NoError(t, err) + require.Len(t, out, 1) + + byName := make(map[string]SupportedModel) + for _, model := range out[0].SupportedModels { + byName[model.Name] = model + } + require.Contains(t, byName, "claude-sonnet-4-6") + require.NotNil(t, byName["claude-sonnet-4-6"].Pricing, "channel pricing keeps priority") + require.Contains(t, byName, "claude-opus-4-5") + require.Equal(t, "anthropic", byName["claude-opus-4-5"].Platform) + require.Len(t, out[0].SupportedModels, 2, "mapped duplicate should not be added twice") +} + +func TestListAvailable_DoesNotAppendAccountMappedModelsWhenChannelRestrictsModels(t *testing.T) { + channels := []Channel{{ + ID: 1, + Name: "restricted-channel", + Status: StatusActive, + RestrictModels: true, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4-6"}, + BillingMode: BillingModeToken, + }}, + }} + repo := &mockChannelRepository{ + listAllFn: func(ctx context.Context) ([]Channel, error) { return channels, nil }, + listAccountMappingModelsByGroupIDsFn: func(ctx context.Context, groupIDs []int64) (map[int64][]string, error) { + return map[int64][]string{10: {"claude-opus-4-5"}}, nil + }, + } + groupRepo := &stubGroupRepoForAvailable{ + activeGroups: []Group{{ID: 10, Name: "anthropic-go", Platform: "anthropic"}}, + } + svc := NewChannelService(repo, groupRepo, nil, nil) + + out, err := svc.ListAvailable(context.Background()) + require.NoError(t, err) + require.Len(t, out, 1) + require.Len(t, out[0].SupportedModels, 1) + require.Equal(t, "claude-sonnet-4-6", out[0].SupportedModels[0].Name) +} + func TestPricingNeedsFallback(t *testing.T) { tests := []struct { name string @@ -221,13 +295,19 @@ func TestSynthesizePricingFromLiteLLM_ImageGenerationMode(t *testing.T) { // LiteLLM mode=image_generation 且渠道未声明模式时,按 image 合成。 lp := &LiteLLMModelPricing{ Mode: "image_generation", + OutputCostPerImage: 0.134, OutputCostPerImageToken: 4e-5, } got := synthesizePricingFromLiteLLM(lp, nil) require.NotNil(t, got) require.Equal(t, BillingModeImage, got.BillingMode) - require.Nil(t, got.PerRequestPrice) + require.NotNil(t, got.PerRequestPrice) + require.InDelta(t, 0.134, *got.PerRequestPrice, 1e-12) require.NotNil(t, got.ImageOutputPrice) + require.Len(t, got.Intervals, 3) + require.Equal(t, "2K", got.Intervals[1].TierLabel) + require.NotNil(t, got.Intervals[1].PerRequestPrice) + require.InDelta(t, 0.201, *got.Intervals[1].PerRequestPrice, 1e-12) } func TestSynthesizePricingFromLiteLLM_RespectsExistingChannelMode(t *testing.T) { @@ -286,6 +366,34 @@ func TestFillGlobalPricingFallback_EmptyPricingFillsFromLiteLLM(t *testing.T) { require.Equal(t, BillingModeImage, models[0].Pricing.BillingMode) require.NotNil(t, models[0].Pricing.ImageOutputPrice) require.InDelta(t, 4e-5, *models[0].Pricing.ImageOutputPrice, 1e-12) + require.NotNil(t, models[0].Pricing.PerRequestPrice) + require.InDelta(t, 0.134, *models[0].Pricing.PerRequestPrice, 1e-12) + require.Len(t, models[0].Pricing.Intervals, 3) + require.Equal(t, "4K", models[0].Pricing.Intervals[2].TierLabel) + require.NotNil(t, models[0].Pricing.Intervals[2].PerRequestPrice) + require.InDelta(t, 0.268, *models[0].Pricing.Intervals[2].PerRequestPrice, 1e-12) +} + +func TestFillGlobalPricingFallback_OpenAIImageWithoutLiteLLMUsesDefaultImagePrice(t *testing.T) { + pricingSvc := newStubPricingServiceFromMap(map[string]*LiteLLMModelPricing{}) + svc := &ChannelService{pricingService: pricingSvc} + + models := []SupportedModel{ + { + Name: "gpt-image-2-2K", + Platform: "openai", + Pricing: &ChannelModelPricing{BillingMode: BillingModeImage}, + }, + } + svc.fillGlobalPricingFallback(models) + require.NotNil(t, models[0].Pricing) + require.Equal(t, BillingModeImage, models[0].Pricing.BillingMode) + require.NotNil(t, models[0].Pricing.PerRequestPrice) + require.InDelta(t, 0.134, *models[0].Pricing.PerRequestPrice, 1e-12) + require.Len(t, models[0].Pricing.Intervals, 3) + require.Equal(t, "2K", models[0].Pricing.Intervals[1].TierLabel) + require.NotNil(t, models[0].Pricing.Intervals[1].PerRequestPrice) + require.InDelta(t, 0.201, *models[0].Pricing.Intervals[1].PerRequestPrice, 1e-12) } func TestFillGlobalPricingFallback_KeepsExistingPrice(t *testing.T) { diff --git a/frontend/src/api/channels.ts b/frontend/src/api/channels.ts index 8962af2c4d8..3aeeafe2d8e 100644 --- a/frontend/src/api/channels.ts +++ b/frontend/src/api/channels.ts @@ -16,6 +16,9 @@ export interface UserAvailableGroup { rate_multiplier: number /** true = 专属分组(小范围授权);false = 公开分组。 */ is_exclusive: boolean + image_price_1k?: number | null + image_price_2k?: number | null + image_price_4k?: number | null } export interface UserPricingInterval { diff --git a/frontend/src/components/channels/AvailableChannelsTable.vue b/frontend/src/components/channels/AvailableChannelsTable.vue index 5b9c0eba207..c9ad0fb3cd4 100644 --- a/frontend/src/components/channels/AvailableChannelsTable.vue +++ b/frontend/src/components/channels/AvailableChannelsTable.vue @@ -1,159 +1,296 @@ + + diff --git a/frontend/src/components/channels/SupportedModelChip.vue b/frontend/src/components/channels/SupportedModelChip.vue index 3fe32e2f75b..de60a0bf2ff 100644 --- a/frontend/src/components/channels/SupportedModelChip.vue +++ b/frontend/src/components/channels/SupportedModelChip.vue @@ -14,11 +14,7 @@ @focusout="onLeave" tabindex="0" > - + @@ -164,8 +160,7 @@ import { // 复用 api/channels.ts 的用户侧最小形态 DTO。 // admin 侧 ChannelModelPricing 字段更多,但结构上是用户 DTO 的超集,admin 视图传入可直接通过结构化子类型检查。 import type { UserPricingInterval, UserSupportedModel } from '@/api/channels' -import PlatformIcon from '@/components/common/PlatformIcon.vue' -import type { GroupPlatform } from '@/types' +import ModelIcon from '@/components/common/ModelIcon.vue' import { platformBadgeClass, platformBorderClass, platformBadgeLightClass } from '@/utils/platformColors' const props = withDefaults( diff --git a/frontend/src/components/common/ModelIcon.vue b/frontend/src/components/common/ModelIcon.vue index 2a05bf71fac..76b3b8b1d5c 100644 --- a/frontend/src/components/common/ModelIcon.vue +++ b/frontend/src/components/common/ModelIcon.vue @@ -19,18 +19,20 @@