[BUILD-432] convert chatbot streaming response view to async (#2157)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c72d472..8c326bd 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,14 +20,14 @@ repos:
       - id: isort
 
   - repo: https://github.com/pycqa/flake8
-    rev: 7.0.0
+    rev: 7.1.0
     hooks:
       - id: flake8
         args: ["--config=setup.cfg"]
         additional_dependencies: [flake8-isort]
 
   - repo: https://github.com/adamchainz/django-upgrade
-    rev: 1.16.0
+    rev: 1.18.0
     hooks:
       - id: django-upgrade
         args: [--target-version, "4.2"]
diff --git a/ankihub/ai/tests/test_use_cases.py b/ankihub/ai/tests/test_use_cases.py
index df79282..a95dca6 100644
--- a/ankihub/ai/tests/test_use_cases.py
+++ b/ankihub/ai/tests/test_use_cases.py
@@ -1,7 +1,7 @@
 import json
 import pickle
 from decimal import Decimal
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
 from langchain_core.documents import Document
@@ -17,6 +17,7 @@ from ankihub.ai.use_cases import (
     Relevance,
     _log_search,
     get_category_from_value,
+    llm_conversation_stream,
     save_notes_action_from_tags_performed_search,
     semantic_similarity_search_from_anki_note_id,
     semantic_similarity_search_from_documents,
@@ -388,3 +389,16 @@ def test_save_notes_action_from_tags_performed_search_with_manually_unselected_f
 def test_get_category_from_value(score_value, expected_relevance):
     result = get_category_from_value(score_value)
     assert result == expected_relevance
+
+
+@patch("ankihub.ai.use_cases.ChatOpenAI")
+def test_llm_conversation_stream(MockedChatOpenAI):
+    question = "my question"
+    mocked_astream = MagicMock(name="mocked_astream")
+
+    MockedChatOpenAI.return_value.astream = mocked_astream
+
+    llm_conversation_stream(question)
+
+    MockedChatOpenAI.assert_called_once_with(temperature=0)
+    mocked_astream.assert_called_once_with(question)
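
Since llm_conversation_stream returns the async iterator from astream without consuming it, the test can stay synchronous and only assert on the mock. For reference, a minimal sketch of how a caller consumes the returned stream (the question text is illustrative; .content is the attribute the view below reads from each chunk):

    import asyncio

    from ankihub.ai.use_cases import llm_conversation_stream

    async def consume():
        # each yielded chunk is a message chunk exposing a .content attribute
        async for chunk in llm_conversation_stream("my question"):
            print(chunk.content, end="")

    asyncio.run(consume())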
diff --git a/ankihub/ai/tests/test_views.py b/ankihub/ai/tests/test_views.py
index bbd45ed..5d23636 100644
--- a/ankihub/ai/tests/test_views.py
+++ b/ankihub/ai/tests/test_views.py
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import os
 import uuid
@@ -5,6 +6,7 @@ from unittest.mock import MagicMock, patch
 from urllib.parse import urlencode
 
 import pytest
+from django.contrib.auth import get_user_model
 from django.db.models import Case, FloatField, Value, When
 from django.test import Client
 from django.urls import reverse
@@ -28,6 +30,8 @@ from ankihub.decks.tests.factories import (
 )
 from config.settings.base import APPS_DIR, AWS_AI_FILES_BUCKET
 
+User = get_user_model()
+
 SIMILARITY_SCORE_TEST_VALUES = [
     0.0,
     0.01,
@@ -49,6 +53,18 @@ TEST_NOTE_IDS = [uuid.uuid4() for _ in range(len(SIMILARITY_SCORE_TEST_VALUES))]
 NOTE_IDS_WITH_SCORES = [(uuid.uuid4(), score) for score in SIMILARITY_SCORE_TEST_VALUES]
 
 
+class AsyncGenerator:
+    def __init__(self, elements):
+        self.elements = elements
+
+    def __aiter__(self):
+        return self._async_generator()
+
+    async def _async_generator(self):
+        for element in self.elements:
+            yield element
+
+
 @pytest.fixture
 def deck_and_note_embeddings():
     deck = DeckFactory()
@@ -992,6 +1008,7 @@ def test_chatbot_without_feature_flag_return_404(user):
     assert response.status_code == 404
 
 
+@pytest.mark.django_db
 @override_flag("chatbot", active=True)
 @patch("ankihub.ai.views.llm_conversation_stream")
 def test_stream_chatbot_response(mock_llm_conversation_stream, user):
@@ -999,12 +1016,18 @@ def test_stream_chatbot_response(mock_llm_conversation_stream, user):
         def __init__(self, content):
             self.content = content
 
-    stream_response = [LLMResponse(content="answer 1"), LLMResponse(content="answer 2")]
+    expected_answer_chunks = [
+        LLMResponse(content="answer 1"),
+        LLMResponse(content="answer 2"),
+    ]
+    stream_response = AsyncGenerator(expected_answer_chunks)
 
-    def mock_stream(question):
-        yield from stream_response
+    async def mock_stream(question):
+        async for response in stream_response:
+            yield response
 
     url = reverse("ai:stream_chatbot_response")
+
     client = Client()
     client.force_login(user=user)
 
@@ -1012,16 +1035,21 @@ def test_stream_chatbot_response(mock_llm_conversation_stream, user):
 
     request_data = {"question": "test"}
 
     response = client.post(url, data=request_data)
 
     assert response.status_code == 200
     assert response["Content-Type"] == "text/event-stream"
     assert response["Cache-Control"] == "no-cache"
     assert response["X-Accel-Buffering"] == "no"
 
-    messages = [line.decode("utf-8").strip() for line in response.streaming_content]
+    async def _collect(async_gen):
+        return [item.decode("utf-8").strip() async for item in async_gen]
+
+    messages = asyncio.run(_collect(response.streaming_content))
+
     assert len(messages) > 0
-    assert messages == [answer.content for answer in stream_response]
+    assert messages == [answer.content for answer in expected_answer_chunks]
 
 
 def test_stream_chatbot_response_without_feature_flag_returns_404(user):
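
Beyond the Django test client, the endpoint can be exercised end to end with any async HTTP client. A hedged sketch using httpx, assuming the app runs under ASGI; the URL is a hypothetical path and authentication (session cookie or Knox token) is omitted:

    import asyncio

    import httpx

    async def main():
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                "http://localhost:8000/ai/stream-chatbot-response/",  # hypothetical path
                data={"question": "test"},
            ) as response:
                # chunks arrive as the LLM streams them, not after the full answer
                async for chunk in response.aiter_text():
                    print(chunk, end="")

    asyncio.run(main())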
diff --git a/ankihub/ai/use_cases.py b/ankihub/ai/use_cases.py
index a3815bf..f0608ed 100644
--- a/ankihub/ai/use_cases.py
+++ b/ankihub/ai/use_cases.py
@@ -229,4 +229,4 @@ def save_notes_action_from_tags_performed_search(
 
 def llm_conversation_stream(question):  # pragma: no cover
     llm = ChatOpenAI(temperature=0)
-    return llm.stream(question)
+    return llm.astream(question)
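
stream and astream are the sync/async pair exposed by langchain chat models: the former returns a regular generator that blocks the worker between chunks, the latter an async generator that yields control back to the event loop while waiting on the API. A short sketch of the difference (the import path is an assumption; the project may import ChatOpenAI from elsewhere):

    from langchain_openai import ChatOpenAI  # import path is an assumption

    llm = ChatOpenAI(temperature=0)
    sync_chunks = llm.stream("hello")    # regular generator: blocks the thread per chunk
    async_chunks = llm.astream("hello")  # async generator: cooperates with the event loop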
diff --git a/ankihub/ai/views.py b/ankihub/ai/views.py
index 7c36de9..426a088 100644
--- a/ankihub/ai/views.py
+++ b/ankihub/ai/views.py
@@ -3,9 +3,11 @@ import os
 from uuid import uuid4
 
 import structlog
+from asgiref.sync import sync_to_async
 from django.conf import settings
 from django.contrib.auth.decorators import login_required
 from django.core.paginator import EmptyPage, Paginator
+from django.db import transaction
 from django.http import (
     Http404,
     HttpResponse,
@@ -32,7 +34,10 @@ from ankihub.ai.use_cases import (
     semantic_similarity_search_from_documents,
     to_search_result,
 )
-from ankihub.common.decorators import knox_token_or_login_required
+from ankihub.common.decorators import (
+    async_knox_token_or_login_required,
+    knox_token_or_login_required,
+)
 from ankihub.common.services.aws import get_s3_client, get_textract_client
 from ankihub.decks.models import (
     Deck,
@@ -472,10 +477,11 @@ def chatbot(request):
     return render(request, "ai/chatbot.html")
 
 
-@knox_token_or_login_required
-@require_http_methods(["POST"])
-def stream_chatbot_response(request):
-    if not flag_is_active(request, "chatbot"):
+@transaction.non_atomic_requests
+@async_knox_token_or_login_required
+async def stream_chatbot_response(request):
+    print("FLAG IS ACTIVE", await sync_to_async(flag_is_active)(request, "chatbot"))
+    if not await sync_to_async(flag_is_active)(request, "chatbot"):
         raise Http404()
 
     user_question = request.POST.get("question")
@@ -483,8 +489,8 @@ def stream_chatbot_response(request):
     if not user_question:
         return HttpResponseBadRequest()
 
-    def message_stream():
-        for answer in llm_conversation_stream(user_question):
+    async def message_stream():
+        async for answer in llm_conversation_stream(user_question):
             yield answer.content
 
     response = StreamingHttpResponse(message_stream(), content_type="text/event-stream")
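
Under ASGI, StreamingHttpResponse accepts an async generator directly and flushes each yielded chunk to the client as it is produced. A minimal sketch of the same pattern in isolation (some_async_source is hypothetical, and the data: framing is only needed if the frontend consumes the stream via EventSource; the view above sends raw chunks):

    from django.http import StreamingHttpResponse

    async def sse_view(request):
        async def event_stream():
            async for chunk in some_async_source():  # hypothetical async generator
                yield f"data: {chunk}\n\n"  # conventional SSE framing

        response = StreamingHttpResponse(event_stream(), content_type="text/event-stream")
        response["Cache-Control"] = "no-cache"  # keep proxies from caching the stream
        response["X-Accel-Buffering"] = "no"  # tell nginx not to buffer the chunks
        return response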
diff --git a/ankihub/common/decorators.py b/ankihub/common/decorators.py
index e02594f..7cd8ec6 100644
--- a/ankihub/common/decorators.py
+++ b/ankihub/common/decorators.py
@@ -1,8 +1,13 @@
+import asyncio
 from functools import wraps
+from urllib.parse import urlparse
 
+from asgiref.sync import sync_to_async

[... diff too long, it was truncated ...]
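
The truncated hunk presumably defines async_knox_token_or_login_required. Purely to illustrate the decorator shape involved, a simplified sketch of an async login guard; this is not the project's code, the name is hypothetical, and Knox token handling is omitted entirely:

    from functools import wraps

    from asgiref.sync import sync_to_async
    from django.http import HttpResponse

    def async_login_required(view_func):  # hypothetical simplified variant
        @wraps(view_func)
        async def wrapper(request, *args, **kwargs):
            # request.user can trigger a lazy DB lookup, so resolve it off the event loop
            is_authenticated = await sync_to_async(
                lambda: request.user.is_authenticated
            )()
            if not is_authenticated:
                return HttpResponse(status=401)
            return await view_func(request, *args, **kwargs)

        return wrapper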

sha: f117485b2a8303138bc5d0e0fe991bb093bb881e