Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions llmx/configs/config.default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,25 @@ providers:
model: uukuguy/speechless-llama2-hermes-orca-platypus-13b
device_map: auto
trust_remote_code: true
ollama:
name: Ollama
description: Local LLM models using the Ollama framework.
models:
- name: llama3.1
max_tokens: 4096
model:
provider: ollama
parameters:
model: llama3.1
- name: qwen2.5-coder:3b
max_tokens: 4096
model:
provider: ollama
parameters:
model: qwen2.5-coder:3b
- name: deepseek-r1
max_tokens: 4096
model:
provider: ollama
parameters:
model: deepseek-r1
68 changes: 68 additions & 0 deletions llmx/generators/text/ollama_textgen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import ollama
import os
from typing import Union, List, Dict
from .base_textgen import TextGenerator
from ...datamodel import Message, TextGenerationConfig, TextGenerationResponse
from ...utils import cache_request, get_models_maxtoken_dict, num_tokens_from_messages
from dataclasses import asdict


class OllamaTextGenerator(TextGenerator):
def __init__(
self,
provider: str = "ollama",
model: str = "llama3.1",
models: Dict = None,
):
super().__init__(provider=provider)
self.model_name = model
self.model_max_token_dict = get_models_maxtoken_dict(models)

def generate(
self,
messages: Union[List[dict], str],
config: TextGenerationConfig = TextGenerationConfig(),
**kwargs,
) -> TextGenerationResponse:
use_cache = config.use_cache
model = config.model or self.model_name
prompt_tokens = num_tokens_from_messages(messages)
max_tokens = max(
self.model_max_token_dict.get(model, 4096) - prompt_tokens - 10, 200
)


if isinstance(messages, list):
prompt = "\n".join([msg["content"] for msg in messages])
else:
prompt = messages

try:
response = ollama.chat(
model=model,
messages=[{"role": "user", "content": prompt}],
)

generated_text = response.message.content

response_obj = TextGenerationResponse(
text=[Message(role="assistant", content=generated_text)],
logprobs=[],
config={"model": model, "max_tokens": max_tokens},
usage={
"prompt_tokens": prompt_tokens,
"completion_tokens": len(generated_text.split()),
"total_tokens": prompt_tokens + len(generated_text.split()),
},
)

if use_cache:
cache_request(cache=self.cache, params=(prompt, model), values=asdict(response_obj))

return response_obj

except Exception as e:
return TextGenerationResponse(text=[Message(role="error", content=f"⚠️ Ollama Error: {str(e)}")])

def count_tokens(self, text) -> int:
return len(text.split()) #Approximate token count for simplicity
5 changes: 5 additions & 0 deletions llmx/generators/text/textgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .palm_textgen import PalmTextGenerator
from .cohere_textgen import CohereTextGenerator
from .anthropic_textgen import AnthropicTextGenerator
from .ollama_textgen import OllamaTextGenerator
import logging

logger = logging.getLogger("llmx")
Expand All @@ -19,6 +20,8 @@ def sanitize_provider(provider: str):
return "hf"
elif provider.lower() == "anthropic" or provider.lower() == "claude":
return "anthropic"
elif provider.lower() == "ollama" or provider.lower() == "llama":
return "ollama"
else:
raise ValueError(
f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
Expand Down Expand Up @@ -58,6 +61,8 @@ def llm(provider: str = None, **kwargs):
return CohereTextGenerator(**kwargs)
elif provider.lower() == "anthropic":
return AnthropicTextGenerator(**kwargs)
elif provider.lower() == "ollama":
return OllamaTextGenerator(**kwargs)
elif provider.lower() == "hf":
try:
import transformers
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"anthropic",
"typer",
"pyyaml",
"ollama"
]
optional-dependencies = {web = ["fastapi", "uvicorn"], transformers = ["transformers[torch]>=4.26","accelerate", "bitsandbytes"]}

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from setuptools import setup
setup()
setup()
10 changes: 10 additions & 0 deletions tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ def test_cohere():
assert ("paris" in answer.lower())
assert len(cohere_response.text) == 2

def test_ollama():
ollama_gen = llm(provider="ollama")
config.model = "llama3.1"
ollama_response = ollama_gen.generate(messages, config=config)
answer = ollama_response.text[0].content
print(ollama_response.text[0].content)

assert ("paris" in answer.lower())
assert len(ollama_response.text) == 1


@pytest.mark.skipif(os.environ.get("LLMX_RUNALL", None) is None
or os.environ.get("LLMX_RUNALL", None) == "False", reason="takes too long")
Expand Down