|
| 1 | +/** |
| 2 | + * Model client adapter — routes inference to Ollama (fine-tuned) or |
| 3 | + * OpenAI (fallback) based on MODEL_BACKEND env var. |
| 4 | + */ |
| 5 | + |
| 6 | +import { generateText } from "ai" |
| 7 | +import { createOpenAI } from "@ai-sdk/openai" |
| 8 | + |
// Backend selector: "ollama" routes to the locally hosted fine-tuned models;
// any other value (default "openai") falls back to hosted GPT.
const MODEL_BACKEND = process.env.MODEL_BACKEND || "openai"
// School slug used as the prefix of the Ollama model tag (see generate()).
const SCHOOL_CODE = process.env.SCHOOL_CODE || "bishop-state"
// Base URL of the Ollama HTTP API server.
const OLLAMA_BASE_URL = process.env.OLLAMA_BASE_URL || "http://localhost:11434"
// Size suffix of the Ollama model tag, e.g. "9b".
const MODEL_SIZE = process.env.MODEL_SIZE || "9b"

// Memoized OpenAI provider instance; created lazily by getOpenAI() so the
// module can load without an API key when the Ollama backend is selected.
let _openai: ReturnType<typeof createOpenAI> | null = null
| 15 | + |
| 16 | +function getOpenAI() { |
| 17 | + if (!_openai) { |
| 18 | + _openai = createOpenAI({ apiKey: process.env.OPENAI_API_KEY || "" }) |
| 19 | + } |
| 20 | + return _openai |
| 21 | +} |
| 22 | + |
| 23 | +async function callOllama(model: string, prompt: string, maxTokens: number): Promise<string> { |
| 24 | + const response = await fetch(`${OLLAMA_BASE_URL}/api/generate`, { |
| 25 | + method: "POST", |
| 26 | + headers: { "Content-Type": "application/json" }, |
| 27 | + body: JSON.stringify({ |
| 28 | + model, |
| 29 | + prompt, |
| 30 | + stream: false, |
| 31 | + options: { |
| 32 | + temperature: 0.3, |
| 33 | + num_predict: maxTokens, |
| 34 | + }, |
| 35 | + }), |
| 36 | + }) |
| 37 | + |
| 38 | + if (!response.ok) { |
| 39 | + throw new Error(`Ollama error: ${response.status} ${response.statusText}`) |
| 40 | + } |
| 41 | + |
| 42 | + const data = await response.json() |
| 43 | + return data.response |
| 44 | +} |
| 45 | + |
| 46 | +async function generate( |
| 47 | + task: "explainer" | "summarizer", |
| 48 | + prompt: string, |
| 49 | + maxTokens: number, |
| 50 | +): Promise<string> { |
| 51 | + if (MODEL_BACKEND === "ollama") { |
| 52 | + const model = `${SCHOOL_CODE}-${task}:${MODEL_SIZE}` |
| 53 | + return callOllama(model, prompt, maxTokens) |
| 54 | + } |
| 55 | + const result = await generateText({ |
| 56 | + model: getOpenAI()("gpt-4o-mini"), |
| 57 | + prompt, |
| 58 | + maxOutputTokens: maxTokens, |
| 59 | + }) |
| 60 | + return result.text |
| 61 | +} |
| 62 | + |
| 63 | +/** |
| 64 | + * Generate a course pairing explanation. |
| 65 | + */ |
| 66 | +export async function generateExplanation( |
| 67 | + prompt: string, |
| 68 | + maxTokens: number = 320, |
| 69 | +): Promise<string> { |
| 70 | + return generate("explainer", prompt, maxTokens) |
| 71 | +} |
| 72 | + |
| 73 | +/** |
| 74 | + * Generate a query result summary. |
| 75 | + */ |
| 76 | +export async function generateSummary( |
| 77 | + prompt: string, |
| 78 | + maxTokens: number = 200, |
| 79 | +): Promise<string> { |
| 80 | + return generate("summarizer", prompt, maxTokens) |
| 81 | +} |
0 commit comments