class OllamaTextGenerator(TextGenerator):
    """Text generator that sends chat requests to Ollama via ``ollama.chat``.

    The generator keeps a default model name and a per-model token-limit
    table (built by ``get_models_maxtoken_dict``) that is used to derive the
    completion budget for every request.
    """

    # Roles forwarded to Ollama unchanged; any other role is sent as "user"
    # so that an unexpected role does not make the whole request fail.
    _CHAT_ROLES = ("system", "user", "assistant", "tool")

    def __init__(
        self,
        provider: str = "ollama",
        model: str = "llama3.1",
        models: Dict = None,
    ):
        """Create a generator.

        Args:
            provider: Provider identifier handed to the base class.
            model: Model used when the request config does not name one.
            models: Model definitions passed to ``get_models_maxtoken_dict``
                to build the per-model token-limit lookup.
        """
        super().__init__(provider=provider)
        self.model_name = model
        self.model_max_token_dict = get_models_maxtoken_dict(models)

    @classmethod
    def _to_chat_messages(cls, messages: Union[List[dict], str]) -> List[dict]:
        """Normalize caller input into a list of ``{"role", "content"}`` dicts."""
        if isinstance(messages, str):
            return [{"role": "user", "content": messages}]

        chat_messages = []
        for msg in messages:
            role = msg.get("role", "user")
            if role not in cls._CHAT_ROLES:
                role = "user"
            chat_messages.append({"role": role, "content": msg["content"]})
        return chat_messages

    def generate(
        self,
        messages: Union[List[dict], str],
        config: TextGenerationConfig = None,
        **kwargs,
    ) -> TextGenerationResponse:
        """Generate a single assistant reply for ``messages``.

        Args:
            messages: Either a plain prompt string or a list of message
                dicts with ``content`` (and optionally ``role``) keys.
            config: Generation settings; a fresh ``TextGenerationConfig`` is
                created when omitted. ``use_cache`` and ``model`` are read
                directly; ``temperature``, ``top_p`` and ``top_k`` are
                forwarded to Ollama when present and not ``None``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A ``TextGenerationResponse`` holding one assistant message, or
            one message with role ``"error"`` if the Ollama call fails.
        """
        # Build the default per call instead of sharing one instance that
        # was created when the function was defined.
        if config is None:
            config = TextGenerationConfig()

        use_cache = config.use_cache
        model = config.model or self.model_name

        # Normalize first so token counting always receives message dicts,
        # whether the caller passed a string or a list.
        chat_messages = self._to_chat_messages(messages)
        prompt_tokens = num_tokens_from_messages(chat_messages)
        max_tokens = max(
            self.model_max_token_dict.get(model, 4096) - prompt_tokens - 10, 200
        )

        # "num_predict" is Ollama's cap on generated tokens; without it the
        # max_tokens value reported in the response config was never applied.
        options = {"num_predict": max_tokens}
        for option_name in ("temperature", "top_p", "top_k"):
            option_value = getattr(config, option_name, None)
            if option_value is not None:
                options[option_name] = option_value

        response_config = {"model": model, "max_tokens": max_tokens}
        cache_key_params = {
            "model": model,
            "messages": chat_messages,
            "options": options,
        }

        if use_cache:
            # NOTE(review): assumes cache_request returns the stored dict (or
            # a falsy value on a miss) when called without `values` - confirm
            # against its definition in utils.
            cached = cache_request(cache=self.cache, params=cache_key_params)
            if cached:
                return TextGenerationResponse(**cached)

        try:
            response = ollama.chat(
                model=model,
                messages=chat_messages,
                options=options,
            )
            # Subscript access works for both dict-style and object-style
            # responses; content may be None, so fall back to "".
            generated_text = response["message"]["content"] or ""
        except Exception as e:
            # Best-effort contract: report the failure as an "error" message
            # instead of raising. The config is included so this response
            # has the same fields populated as a successful one.
            return TextGenerationResponse(
                text=[Message(role="error", content=f"⚠️ Ollama Error: {str(e)}")],
                config=response_config,
            )

        completion_tokens = self.count_tokens(generated_text)
        response_obj = TextGenerationResponse(
            text=[Message(role="assistant", content=generated_text)],
            logprobs=[],
            config=response_config,
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        )

        if use_cache:
            cache_request(
                cache=self.cache,
                params=cache_key_params,
                values=asdict(response_obj),
            )

        return response_obj

    def count_tokens(self, text) -> int:
        """Return an approximate token count: the number of whitespace-separated words."""
        if not text:
            return 0
        return len(text.split())
def test_ollama():
    """Smoke-test the Ollama provider end to end.

    Sends the module-level ``messages`` with the module-level ``config``
    (model set to ``llama3.1``) and expects exactly one response message
    whose content mentions "paris".

    NOTE(review): presumably needs a reachable Ollama server with the
    ``llama3.1`` model available - confirm before running in CI.
    """
    generator = llm(provider="ollama")
    config.model = "llama3.1"
    result = generator.generate(messages, config=config)
    reply = result.text[0].content
    print(result.text[0].content)

    assert "paris" in reply.lower()
    assert len(result.text) == 1