Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions backend/internal/handler/available_channel_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,15 @@ func (h *AvailableChannelHandler) featureEnabled(c *gin.Context) bool {
// 订阅视觉加深),并用 RateMultiplier 作为默认倍率;用户专属倍率前端走
// /groups/rates,和 API 密钥页面保持一致。
type userAvailableGroup struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
SubscriptionType string `json:"subscription_type"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
SubscriptionType string `json:"subscription_type"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
}

// userSupportedModelPricing 用户可见的定价字段白名单。
Expand Down Expand Up @@ -219,6 +222,9 @@ func filterUserVisibleGroups(
SubscriptionType: g.SubscriptionType,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
})
}
return visible
Expand Down
173 changes: 172 additions & 1 deletion backend/internal/service/channel_available.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ type AvailableGroupRef struct {
SubscriptionType string
RateMultiplier float64
IsExclusive bool
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
}

// AvailableChannel 可用渠道视图:用于「可用渠道」页面展示渠道基础信息 +
Expand All @@ -34,6 +37,10 @@ type AvailableChannel struct {
SupportedModels []SupportedModel
}

type channelAccountMappingModelLister interface {
ListAccountMappingModelsByGroupIDs(ctx context.Context, groupIDs []int64) (map[int64][]string, error)
}

// ListAvailable 返回所有渠道的可用视图:每个渠道附带关联分组信息与支持模型列表。
//
// 支持模型通过 (*Channel).SupportedModels() 计算(mapping ∪ pricing 并联)。
Expand All @@ -56,6 +63,7 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel,
return nil, fmt.Errorf("list active groups: %w", err)
}
groupByID := make(map[int64]AvailableGroupRef, len(groups))
groupModelsListByID := make(map[int64]GroupModelsListConfig, len(groups))
for i := range groups {
g := groups[i]
groupByID[g.ID] = AvailableGroupRef{
Expand All @@ -65,15 +73,32 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel,
SubscriptionType: g.SubscriptionType,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
}
groupModelsListByID[g.ID] = g.ModelsListConfig
}

accountMappedModelsByGroupID := map[int64][]string{}
if lister, ok := s.repo.(channelAccountMappingModelLister); ok {
groupIDs := collectActiveGroupIDs(groupByID)
if len(groupIDs) > 0 {
accountMappedModelsByGroupID, err = lister.ListAccountMappingModelsByGroupIDs(ctx, groupIDs)
if err != nil {
return nil, fmt.Errorf("list account mapping models: %w", err)
}
}
}

out := make([]AvailableChannel, 0, len(channels))
channelLinkedGroupIDs := make(map[int64]struct{}, len(groupByID))
for i := range channels {
ch := &channels[i]
groups := make([]AvailableGroupRef, 0, len(ch.GroupIDs))
for _, gid := range ch.GroupIDs {
if ref, ok := groupByID[gid]; ok {
channelLinkedGroupIDs[gid] = struct{}{}
groups = append(groups, ref)
}
}
Expand All @@ -82,6 +107,9 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel,
ch.normalizeBillingModelSource()

supported := ch.SupportedModels()
if !ch.RestrictModels {
supported = appendAccountMappedSupportedModels(supported, groups, accountMappedModelsByGroupID)
}
s.fillGlobalPricingFallback(supported)

out = append(out, AvailableChannel{
Expand All @@ -95,13 +123,114 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel,
SupportedModels: supported,
})
}
out = append(out, s.buildDirectGroupAvailableChannels(groupByID, groupModelsListByID, channelLinkedGroupIDs, accountMappedModelsByGroupID)...)

sort.SliceStable(out, func(i, j int) bool {
return strings.ToLower(out[i].Name) < strings.ToLower(out[j].Name)
})
return out, nil
}

func collectActiveGroupIDs(groupByID map[int64]AvailableGroupRef) []int64 {
groupIDs := make([]int64, 0, len(groupByID))
for id := range groupByID {
groupIDs = append(groupIDs, id)
}
sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] })
return groupIDs
}

func (s *ChannelService) buildDirectGroupAvailableChannels(
groupByID map[int64]AvailableGroupRef,
groupModelsListByID map[int64]GroupModelsListConfig,
channelLinkedGroupIDs map[int64]struct{},
mappedByGroupID map[int64][]string,
) []AvailableChannel {
out := make([]AvailableChannel, 0)
groupIDs := collectActiveGroupIDs(groupByID)
for _, groupID := range groupIDs {
if _, linked := channelLinkedGroupIDs[groupID]; linked {
continue
}
group := groupByID[groupID]
models := directGroupSupportedModels(group, groupModelsListByID[groupID], mappedByGroupID[groupID])
if len(models) == 0 {
continue
}
s.fillGlobalPricingFallback(models)
out = append(out, AvailableChannel{
ID: -groupID,
Name: group.Name,
Description: "分组直连",
Status: StatusActive,
RestrictModels: groupModelsListByID[groupID].Enabled,
Groups: []AvailableGroupRef{group},
SupportedModels: models,
})
}
return out
}

func directGroupSupportedModels(group AvailableGroupRef, cfg GroupModelsListConfig, mappedModels []string) []SupportedModel {
modelNames := mappedModels
if cfg.Enabled && len(cfg.Models) > 0 {
modelNames = cfg.Models
}
seen := make(map[string]struct{}, len(modelNames))
out := make([]SupportedModel, 0, len(modelNames))
for _, model := range modelNames {
name := strings.TrimSpace(model)
if name == "" {
continue
}
key := strings.ToLower(name)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, SupportedModel{Name: name, Platform: group.Platform})
}
sort.SliceStable(out, func(i, j int) bool { return strings.ToLower(out[i].Name) < strings.ToLower(out[j].Name) })
return out
}

func appendAccountMappedSupportedModels(supported []SupportedModel, groups []AvailableGroupRef, mappedByGroupID map[int64][]string) []SupportedModel {
if len(groups) == 0 || len(mappedByGroupID) == 0 {
return supported
}
type modelKey struct {
platform string
model string
}
seen := make(map[modelKey]struct{}, len(supported))
for _, model := range supported {
seen[modelKey{platform: model.Platform, model: strings.ToLower(model.Name)}] = struct{}{}
}
out := append([]SupportedModel(nil), supported...)
for _, group := range groups {
models := mappedByGroupID[group.ID]
for _, model := range models {
name := strings.TrimSpace(model)
if name == "" {
continue
}
key := modelKey{platform: group.Platform, model: strings.ToLower(name)}
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, SupportedModel{Name: name, Platform: group.Platform})
}
}
sort.SliceStable(out, func(i, j int) bool {
if out[i].Platform != out[j].Platform {
return out[i].Platform < out[j].Platform
}
return strings.ToLower(out[i].Name) < strings.ToLower(out[j].Name)
})
return out
}

// fillGlobalPricingFallback 对未命中渠道定价的支持模型,从全局 LiteLLM 数据合成一份
// 展示用定价。仅用于「可用渠道」展示,不影响真实计费链路。
//
Expand All @@ -119,6 +248,10 @@ func (s *ChannelService) fillGlobalPricingFallback(models []SupportedModel) {
continue
}
lp := s.pricingService.GetModelPricing(models[i].Name)
if lp == nil && isOpenAIImageGenerationModel(models[i].Name) {
models[i].Pricing = synthesizeDefaultImagePricing(models[i].Pricing)
continue
}
if lp == nil {
continue
}
Expand Down Expand Up @@ -171,12 +304,17 @@ func synthesizePricingFromLiteLLM(lp *LiteLLMModelPricing, existing *ChannelMode
}

if mode == BillingModeImage || mode == BillingModePerRequest {
perRequestPrice := nonZeroPtr(lp.OutputCostPerImage)
if mode == BillingModeImage && perRequestPrice == nil {
perRequestPrice = availableFloat64Ptr(defaultImagePrice1K)
}
return &ChannelModelPricing{
BillingMode: mode,
PerRequestPrice: nonZeroPtr(lp.OutputCostPerImage),
PerRequestPrice: perRequestPrice,
ImageOutputPrice: nonZeroPtr(lp.OutputCostPerImageToken),
InputPrice: nonZeroPtr(lp.InputCostPerToken),
OutputPrice: nonZeroPtr(lp.OutputCostPerToken),
Intervals: synthesizeImagePriceTiers(mode, perRequestPrice),
}
}
return &ChannelModelPricing{
Expand All @@ -189,9 +327,42 @@ func synthesizePricingFromLiteLLM(lp *LiteLLMModelPricing, existing *ChannelMode
}
}

const defaultImagePrice1K = 0.134

func synthesizeDefaultImagePricing(existing *ChannelModelPricing) *ChannelModelPricing {
mode := BillingModeImage
if existing != nil && existing.BillingMode != "" {
mode = existing.BillingMode
}
base := availableFloat64Ptr(defaultImagePrice1K)
return &ChannelModelPricing{
BillingMode: mode,
PerRequestPrice: base,
Intervals: synthesizeImagePriceTiers(mode, base),
}
}

func nonZeroPtr(v float64) *float64 {
if v == 0 {
return nil
}
return &v
}

func availableFloat64Ptr(v float64) *float64 {
return &v
}

func synthesizeImagePriceTiers(mode BillingMode, base *float64) []PricingInterval {
if mode != BillingModeImage || base == nil {
return nil
}
price1K := *base
price2K := price1K * 1.5
price4K := price1K * 2
return []PricingInterval{
{TierLabel: "1K", PerRequestPrice: &price1K, SortOrder: 1},
{TierLabel: "2K", PerRequestPrice: &price2K, SortOrder: 2},
{TierLabel: "4K", PerRequestPrice: &price4K, SortOrder: 3},
}
}
Loading
Loading