Add support for reasoning to Ollama backend

Kovid Goyal 2025-09-13 08:38:44 +05:30
parent 9a43f11b2b
commit 817f936a56
2 changed files with 19 additions and 16 deletions


@@ -41,25 +41,17 @@ class Model(NamedTuple):
     family: str
     families: Sequence[str]
     modified_at: datetime.datetime
+    can_think: bool
 
     @classmethod
-    def from_dict(cls, x: dict[str, object]) -> 'Model':
-        mid = x['model']
+    def from_dict(cls, x: dict[str, Any], details: dict[str, Any]) -> 'Model':
         d = x.get('details', {})
         return Model(
-            name=x['name'], id=mid, family=d.get('family', ''), families=d.get('families', ()),
-            modified_at=datetime.datetime.fromisoformat(x['modified_at'])
+            name=x['name'], id=x['model'], family=d.get('family', ''), families=d.get('families', ()),
+            modified_at=datetime.datetime.fromisoformat(x['modified_at']), can_think='thinking' in details['capabilities'],
         )
 
 
-def parse_models_list(entries: list[dict[str, Any]]) -> dict[str, Model]:
-    ans = {}
-    for entry in entries:
-        e = Model.from_dict(entry)
-        ans[e.id] = e
-    return ans
-
-
 def api_url(path: str = '') -> str:
     ans = pref('api_url') or OllamaAI.DEFAULT_URL
     purl = urlparse(ans)
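
The new can_think field is derived from the capability list Ollama reports per model. A minimal sketch of the shapes involved; the model name and values below are invented, but the field names follow Ollama's /api/tags and /api/show responses:

    # Hypothetical /api/tags entry and /api/show response, for illustration only.
    tag_entry = {
        'name': 'qwen3:8b', 'model': 'qwen3:8b',
        'modified_at': '2025-09-13T08:00:00+05:30',
        'details': {'family': 'qwen3', 'families': ['qwen3']},
    }
    show_response = {'capabilities': ['completion', 'thinking']}

    # can_think is a simple membership test on the reported capabilities
    can_think = 'thinking' in show_response['capabilities']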
@@ -72,7 +64,15 @@ def api_url(path: str = '') -> str:
 
 @lru_cache(2)
 def get_available_models() -> dict[str, Model]:
-    return parse_models_list(json.loads(download_data(api_url('api/tags')))['models'])
+    ans = {}
+    o = opener()
+    for model in json.loads(download_data(api_url('api/tags')))['models']:
+        rq = Request(api_url('api/show'), data=json.dumps({'model': model['model']}).encode(), method='POST')
+        with o.open(rq) as f:
+            details = json.loads(f.read())
+        e = Model.from_dict(model, details)
+        ans[e.id] = e
+    return ans
 
 
 def does_model_exist_locally(model_id: str) -> bool:
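
With parse_models_list gone, listing models now costs one extra POST to /api/show per installed model, which the @lru_cache(2) decorator amortizes across calls. A standalone approximation of the new flow using only urllib; opener() and download_data are project helpers not shown in this diff, and the default address is an assumption:

    import json
    from urllib.request import Request, urlopen

    base = 'http://localhost:11434/'  # Ollama's usual default address

    # List installed models, then probe each one for its capabilities.
    with urlopen(base + 'api/tags') as f:
        models = json.loads(f.read())['models']
    for m in models:
        rq = Request(base + 'api/show', data=json.dumps({'model': m['model']}).encode(), method='POST')
        with urlopen(rq) as f:
            print(m['model'], json.loads(f.read()).get('capabilities', []))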
@@ -101,6 +101,8 @@ def model_choice_for_text() -> Model:
 
 def chat_request(data: dict[str, Any], model: Model) -> Request:
     data['stream'] = True
+    if model.can_think:
+        data['think'] = True
     return Request(
         api_url('api/chat'), data=json.dumps(data).encode('utf-8'),
         headers=dict(headers()), method='POST')
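
A thinking-capable model thus gets think: true alongside the usual streaming flag. Roughly the body chat_request now serializes; the model name and message here are invented for illustration:

    payload = {
        'model': 'qwen3:8b',
        'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
        'stream': True,
        'think': True,  # only set when the model advertises the 'thinking' capability
    }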
@@ -119,9 +121,10 @@ def as_chat_responses(d: dict[str, Any], model: Model) -> Iterator[ChatResponse]
     if has_metadata and (dr := d['done_reason']) != 'stop':
         yield ChatResponse(exception=ResultBlocked(custom_message=_('Result was blocked for reason: {}').format(dr)))
         return
-    if has_metadata or content:
+    reasoning = msg.get('thinking') or ''
+    if has_metadata or content or reasoning:
         yield ChatResponse(
-            type=ChatMessageType.assistant, content=content, has_metadata=has_metadata, model=model.id, plugin_name=OllamaAI.name)
+            type=ChatMessageType.assistant, reasoning=reasoning, content=content, has_metadata=has_metadata, model=model.id, plugin_name=OllamaAI.name)
 
 
 def read_streaming_response(rq: Request, provider_name: str = 'AI provider') -> Iterator[dict[str, Any]]:
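
On the response side, Ollama streams newline-delimited JSON in which reasoning tokens arrive under message.thinking, separate from message.content; the extra reasoning check ensures chunks carrying only thinking text are still yielded. A self-contained sketch over an invented stream:

    import json

    # Invented sample of a streamed /api/chat response: reasoning tokens
    # arrive under message.thinking, answer tokens under message.content.
    stream = [
        b'{"message": {"role": "assistant", "thinking": "The user wants..."}, "done": false}',
        b'{"message": {"role": "assistant", "content": "Rayleigh scattering"}, "done": false}',
        b'{"message": {"role": "assistant", "content": ""}, "done": true, "done_reason": "stop"}',
    ]
    for line in stream:
        d = json.loads(line)
        msg = d.get('message', {})
        reasoning, content = msg.get('thinking') or '', msg.get('content') or ''
        # mirrors the new check: emit when there is metadata, content or reasoning
        if d.get('done') or content or reasoning:
            print(repr(reasoning), repr(content), d.get('done'))

The second changed file only touches the develop_text_chat helper, giving its sample prompt more concrete instructions: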


@@ -262,7 +262,7 @@ def develop_text_chat(
     acc = StreamedResponseAccumulator()
     messages = messages or (
         ChatMessage(type=ChatMessageType.system, query='You are William Shakespeare.'),
-        ChatMessage('Give me twenty lines on my supremely beautiful wife.')
+        ChatMessage('Write twenty lines on my supremely beautiful wife. Assume she has honey gold skin and a brilliant smile.')
     )
     for x in text_chat(messages, use_model):
         if x.exception: