[BUILD-427] feat: Allow user to send arbitrary messages and receive response (#2145)

[BUILD-427] feat: Allow user to send arbitrary messages and receive response (#2145)

diff --git a/ankihub/ai/tests/test_views.py b/ankihub/ai/tests/test_views.py
index 41e9214..c384945 100644
--- a/ankihub/ai/tests/test_views.py
+++ b/ankihub/ai/tests/test_views.py
@@ -967,3 +967,59 @@ class TestUpdateModalDisplayUserSettingView:
             user.settings.flashcard_selector["show_unsuspend_confirmation_modal"]
             is False
         )
+
+
def test_chatbot(user):
    """The chatbot page renders with the expected template for a logged-in user."""
    client = Client()
    client.force_login(user=user)

    response = client.get(reverse("ai:chatbot"))

    assert response.status_code == 200
    template_names = [template.name for template in response.templates]
    assert "ai/chatbot.html" in template_names
+
+
@patch("ankihub.ai.views.llm_conversation_stream")
def test_stream_chatbot_response(mock_llm_conversation_stream, user):
    """A valid question streams back every LLM chunk with streaming headers set."""

    class LLMResponse:
        def __init__(self, content):
            self.content = content

    chunks = [LLMResponse(content="answer 1"), LLMResponse(content="answer 2")]

    def fake_stream(question):
        yield from chunks

    mock_llm_conversation_stream.side_effect = fake_stream

    client = Client()
    client.force_login(user=user)

    response = client.post(
        reverse("ai:stream_chatbot_response"), data={"question": "test"}
    )

    assert response.status_code == 200
    assert response["Content-Type"] == "text/event-stream"
    assert response["Cache-Control"] == "no-cache"
    assert response["X-Accel-Buffering"] == "no"

    messages = [part.decode("utf-8").strip() for part in response.streaming_content]
    assert len(messages) > 0
    assert messages == [chunk.content for chunk in chunks]
+
+
@patch("ankihub.ai.views.llm_conversation_stream")
def test_stream_chatbot_with_invalid_question(mock_llm_conversation_stream, user):
    """An empty question is rejected with 400 before the LLM is ever invoked."""
    client = Client()
    client.force_login(user=user)

    response = client.post(
        reverse("ai:stream_chatbot_response"), data={"question": ""}
    )

    assert response.status_code == 400
    mock_llm_conversation_stream.assert_not_called()
diff --git a/ankihub/ai/urls.py b/ankihub/ai/urls.py
index 6fb0c12..25f7a68 100644
--- a/ankihub/ai/urls.py
+++ b/ankihub/ai/urls.py
@@ -80,4 +80,10 @@ urlpatterns = [
         views.update_modal_display_user_setting,
         name="update-modal-settings",
     ),
+    path("chatbot/", view=views.chatbot, name="chatbot"),
+    path(
+        "stream-chatbot-response/",
+        views.stream_chatbot_response,
+        name="stream_chatbot_response",
+    ),
 ]
diff --git a/ankihub/ai/use_cases.py b/ankihub/ai/use_cases.py
index e3c2c19..a3815bf 100644
--- a/ankihub/ai/use_cases.py
+++ b/ankihub/ai/use_cases.py
@@ -3,6 +3,8 @@ import uuid
 from dataclasses import dataclass
 from enum import Enum
 
+from langchain_openai import ChatOpenAI
+
 from ankihub.ai import operations
 from ankihub.ai.models import (
     Ada002SearchDocument,
@@ -223,3 +225,8 @@ def save_notes_action_from_tags_performed_search(
         action=NotesActionChoices.UNSUSPEND,
         created_through=NotesActionCreatedThroughChoices.STUDY_AIDS,
     )
+
+
def llm_conversation_stream(question):  # pragma: no cover
    """Stream the LLM's answer to ``question`` as incremental chunks.

    Returns the iterator produced by ``ChatOpenAI.stream``; callers (the
    streaming chatbot view) read ``.content`` off each yielded chunk.
    ``temperature=0`` minimizes sampling randomness in the responses.
    """
    llm = ChatOpenAI(temperature=0)
    return llm.stream(question)
diff --git a/ankihub/ai/views.py b/ankihub/ai/views.py
index da9e626..f497372 100644
--- a/ankihub/ai/views.py
+++ b/ankihub/ai/views.py
@@ -6,7 +6,13 @@ import structlog
 from django.conf import settings
 from django.contrib.auth.decorators import login_required
 from django.core.paginator import EmptyPage, Paginator
-from django.http import HttpResponse, HttpResponseBadRequest, JsonResponse, QueryDict
+from django.http import (
+    HttpResponse,
+    HttpResponseBadRequest,
+    JsonResponse,
+    QueryDict,
+    StreamingHttpResponse,
+)
 from django.shortcuts import get_object_or_404, render
 from django.urls import reverse
 from django.views.decorators.http import require_http_methods
@@ -19,6 +25,7 @@ from ankihub.ai.models import SourceTypeChoices
 from ankihub.ai.operations import URLLoader, create_documents
 from ankihub.ai.use_cases import (
     SearchResult,
+    llm_conversation_stream,
     save_notes_action_from_tags_performed_search,
     semantic_similarity_search_from_documents,
     to_search_result,
@@ -454,3 +461,26 @@ def update_modal_display_user_setting(request):
         )
 
     return HttpResponse(status=204)
+
+
@knox_token_or_login_required
def chatbot(request):
    """Render the chatbot page (``ai/chatbot.html``) for an authenticated user."""
    return render(request, "ai/chatbot.html")
+
+
@knox_token_or_login_required
@require_http_methods(["POST"])
def stream_chatbot_response(request):
    """Stream an LLM answer to the POSTed ``question`` back to the client.

    Responds 400 when ``question`` is missing or empty; otherwise streams
    the raw chunk contents with streaming-friendly headers.
    """
    user_question = request.POST.get("question")

    # Reject missing/empty questions before touching the LLM.
    # NOTE(review): a whitespace-only question passes this truthiness check
    # and is forwarded to the LLM — consider .strip() if that's unintended.
    if not user_question:
        return HttpResponseBadRequest()

    def message_stream():
        # NOTE(review): chunks are emitted as raw text, not SSE-framed
        # ("data: ...\n\n") despite the text/event-stream content type.
        # The chatbot template consumes the body manually via
        # response.body.getReader(), not EventSource — confirm no client
        # relies on real SSE framing.
        for answer in llm_conversation_stream(user_question):
            yield answer.content

    response = StreamingHttpResponse(message_stream(), content_type="text/event-stream")
    response["Cache-Control"] = "no-cache"
    # Disable proxy (nginx) buffering so chunks reach the client immediately.
    response["X-Accel-Buffering"] = "no"
    return response
diff --git a/ankihub/templates/ai/chatbot.html b/ankihub/templates/ai/chatbot.html
new file mode 100644
index 0000000..0521ed2
--- /dev/null
+++ b/ankihub/templates/ai/chatbot.html
@@ -0,0 +1,81 @@
+{% extends "base_embed.html" %}
+{% block title %}Chat{% endblock %}
+{% load custom_filters %}
+{% load waffle_tags %}
+{% load custom_tags %}
+
+{% load static %}
+
+{% block css %}
+  {{ block.super }}
+  <link href="{% static 'css/tagify_custom.css' %}" rel="stylesheet">
+{% endblock css %}
+
+{% block content %}
+  <div class="ml-5 mt-5 mr-5">
+    <div class="container mx-auto p-4">
+      <h1 class="text-3xl font-bold mb-6 text-center">Chatbot</h1>
+
+      <div class="bg-white rounded-lg shadow-lg p-6">
+          <div id="output" class="h-80 overflow-y-auto mb-4 p-4 bg-gray-50 border border-gray-200 rounded"></div>
+
+          <form id="postForm" class="flex">
+              <input type="text" id="questionData" name="question" class="flex-1 p-2 border border-gray-300 rounded-l-md focus:outline-none focus:ring-2 focus:ring-blue-500" placeholder="Type your message..." required>
+              <button type="submit" class="p-2 bg-blue-500 text-white rounded-r-md hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-500">Send</button>
+          </form>
+      </div>
+  </div>
+
+  <script>
+      document.getElementById('postForm').addEventListener('submit', async function(event) {
+          event.preventDefault();
+
+          const outputElement = document.getElementById('output');
+          const inputData = document.getElementById('questionData');
+          const question = inputData.value;
+          inputData.value = '';
+
+          outputElement.innerHTML += `<div class="mb-2"><strong>You:</strong> ${question}</div>`;
+
+          const formData = new FormData();
+          formData.append('question', question);
+
+          try {
+              const response = await fetch(`{% url "ai:stream_chatbot_response" %}`, {
+                  method: 'POST',
+                  headers: {
+                      'Content-Type': 'application/x-www-form-urlencoded',
+                      'X-CSRFToken': `{{ csrf_token }}`
+                  },
+                  body: new URLSearchParams(formData)
+              });
+              if (!response.ok) {
+                  throw new Error('Network response was not ok');
+              }
+
+              const reader = response.body.getReader();
+              const decoder = new TextDecoder();
+
+              let botMessage = document.createElement('div');
+              botMessage.classList.add('mb-2');
+              const strongElement = document.createElement('strong');
+              strongElement.textContent = 'Bot: ';
+
+              botMessage.appendChild(strongElement);
+              outputElement.appendChild(botMessage)
+              while (true) {

[... diff too long, it was truncated ...]

GitHub
sha: 1e638612d902193463b2edfbcaba4db64b1930b0