import { tagged } from '@mirage/service-logging';
import * as rx from 'rxjs';
import { createAPIv2GRPCWebPromiseClient } from '../api_v2';
import { DocSummarizationApiV2 } from './gen/doc_summarization_connectweb';
import { DocumentId } from './gen/doc_summarization_pb';

import type { DocSummaryQnaResponse } from './types';

// Module-level logger; the observable factories in this file use it to report
// errors raised while consuming the streaming responses.
const logger = tagged('context_Observables');

/**
 * Wraps the streaming `getDocSummary` RPC in a cold Observable.
 *
 * A new request is issued for every subscription. Each streamed response that
 * carries a non-empty `summary` is emitted as `{ answer, requestId }`;
 * responses without a summary are skipped. The Observable completes when the
 * stream ends and errors (after logging) if the stream fails.
 *
 * Unsubscribing aborts the in-flight request, so an abandoned subscription
 * does not keep the stream open.
 *
 * @param resultId - Identifier of the document to summarize. When falsy, no
 *   request is made and the Observable completes immediately without emitting.
 * @param template - Forwarded unchanged as the request's `template` field.
 * @param modelName - Forwarded unchanged as the request's `modelName` field.
 * @param useMapReduce - Sent to the backend as the string `'True'` or `'False'`.
 * @param mapReduceConfigs - Forwarded unchanged as the request's
 *   `mapreduceConfigs` field.
 * @returns An Observable of summary chunks for the given document.
 */
export function createDocSummarizationObservable(
  resultId: DocumentId,
  template: string = '',
  modelName: string = '',
  useMapReduce: boolean = false,
  mapReduceConfigs: { [key: string]: number } = {},
): rx.Observable<DocSummaryQnaResponse> {
  return new rx.Observable<DocSummaryQnaResponse>((subscriber) => {
    // Lets the teardown below cancel the streaming call on unsubscribe.
    const abortController = new AbortController();

    if (resultId) {
      const summarizationQnaClient = createAPIv2GRPCWebPromiseClient(
        DocSummarizationApiV2,
      );

      // Fire-and-forget: every outcome is reported through `subscriber`.
      void (async () => {
        try {
          for await (const res of summarizationQnaClient.getDocSummary(
            {
              id: resultId,
              template: template,
              modelName: modelName,
              useMapreduce: useMapReduce ? 'True' : 'False',
              mapreduceConfigs: mapReduceConfigs,
            },
            { signal: abortController.signal },
          )) {
            if (res?.summary) {
              subscriber.next({
                answer: res.summary,
                requestId: res.requestId,
              });
            }
          }
          subscriber.complete();
        } catch (error) {
          // A failure caused by our own abort (i.e. the consumer unsubscribed)
          // is expected and must not be logged or surfaced as an error.
          if (abortController.signal.aborted) {
            return;
          }
          logger.error('createDocSummarizationObservable error', error);
          subscriber.error(error);
        }
      })();
    } else {
      // Nothing to request: complete instead of leaving subscribers hanging.
      subscriber.complete();
    }

    return () => abortController.abort();
  });
}

/**
 * Wraps the streaming `getDocAnswer` RPC in a cold Observable.
 *
 * A new request is issued for every subscription. Each streamed response that
 * carries a non-empty `answer` is emitted as `{ answer, requestId }`;
 * responses without an answer are skipped. The Observable completes when the
 * stream ends and errors (after logging) if the stream fails.
 *
 * Unsubscribing aborts the in-flight request, so an abandoned subscription
 * does not keep the stream open.
 *
 * @param resultId - Identifier of the document the question is about.
 * @param question - The question to ask. When either `resultId` or `question`
 *   is falsy, no request is made and the Observable completes immediately
 *   without emitting.
 * @param template - Forwarded unchanged as the request's `template` field.
 * @param modelName - Forwarded unchanged as the request's `modelName` field.
 * @param topkRAG - Forwarded unchanged as the request's `topkRAG` field; how
 *   the default of `-1` is interpreted is decided by the backend and is not
 *   visible in this file.
 * @returns An Observable of answer chunks for the given question.
 */
export function createDocAnswerObservable(
  resultId: DocumentId,
  question: string,
  template: string = '',
  modelName: string = '',
  topkRAG: number = -1,
): rx.Observable<DocSummaryQnaResponse> {
  return new rx.Observable<DocSummaryQnaResponse>((subscriber) => {
    // Lets the teardown below cancel the streaming call on unsubscribe.
    const abortController = new AbortController();

    if (resultId && question) {
      const summarizationQnaClient = createAPIv2GRPCWebPromiseClient(
        DocSummarizationApiV2,
      );

      // Fire-and-forget: every outcome is reported through `subscriber`.
      void (async () => {
        try {
          for await (const res of summarizationQnaClient.getDocAnswer(
            {
              id: resultId,
              question: question,
              template: template,
              modelName: modelName,
              topkRAG: topkRAG,
            },
            { signal: abortController.signal },
          )) {
            if (res?.answer) {
              subscriber.next({
                answer: res.answer,
                requestId: res.requestId,
              });
            }
          }
          subscriber.complete();
        } catch (error) {
          // A failure caused by our own abort (i.e. the consumer unsubscribed)
          // is expected and must not be logged or surfaced as an error.
          if (abortController.signal.aborted) {
            return;
          }
          logger.error('createDocAnswerObservable error', error);
          subscriber.error(error);
        }
      })();
    } else {
      // Nothing to request: complete instead of leaving subscribers hanging.
      subscriber.complete();
    }

    return () => abortController.abort();
  });
}
