import corpus from './corpus.mjs';

// Model name shown in TrainModel's log messages.
const modelname = 'AgGPT';
// Word cap passed to generateSentence() as maxLength (the prompt's words count too).
const OutputLength = 15;

/**
 * Dense matrix product of two matrices stored as arrays of rows.
 *
 * The column count is read from B[0], so B must be non-empty whenever A is.
 * The inner sum runs over the entries of each row of A.
 *
 * @param {number[][]} A - Left operand, one array per row.
 * @param {number[][]} B - Right operand; B[k] pairs with A[i][k].
 * @returns {number[][]} One row per row of A, each B[0].length wide.
 */
function matMul(A, B) {
    return A.map(leftRow => {
        const productRow = [];
        for (let col = 0; col < B[0].length; col++) {
            let acc = 0;
            for (let k = 0; k < leftRow.length; k++) {
                acc += leftRow[k] * B[k][col];
            }
            productRow.push(acc);
        }
        return productRow;
    });
}

/**
 * Numerically stable softmax.
 *
 * Subtracts the maximum score before exponentiating so large inputs do not
 * overflow Math.exp.
 *
 * @param {number[]} x - Raw scores.
 * @returns {number[]} Non-negative weights summing to 1 (empty for empty input).
 */
function softmax(x) {
    const maxVal = Math.max(...x);
    const exps = x.map(v => Math.exp(v - maxVal));
    const sum = exps.reduce((a, b) => a + b, 0);
    return exps.map(e => e / sum);
}

/**
 * Single-head dot-product attention (no 1/sqrt(d) scaling is applied).
 *
 * Every query row is scored against every key row. The scores go through
 * softmax, and the resulting weights mix the value rows.
 *
 * @param {number[][]} Q - Query rows.
 * @param {number[][]} K - Key rows, expected to be as wide as the query rows.
 * @param {number[][]} V - Value rows, one per key row.
 * @returns {number[][]} One row per query, each as wide as a value row.
 */
function selfAttention(Q, K, V) {
    const valueWidth = V.length > 0 ? V[0].length : 0;

    return Q.map(qRow => {
        const scores = K.map(kRow =>
            qRow.reduce((acc, q, idx) => acc + q * kRow[idx], 0)
        );
        const weights = softmax(scores);
        // Lay the result out as [query][feature] so it lines up row-for-row
        // with Q. Building it feature-first would return the transpose.
        return Array.from({ length: valueWidth }, (_, j) =>
            weights.reduce((sum, weight, k) => sum + weight * V[k][j], 0)
        );
    });
}

/**
 * Multi-head attention without learned projections.
 *
 * The feature columns of Q, K and V are split into `numHeads` contiguous
 * slices, and attention runs independently on each slice. Each position's
 * head outputs are then concatenated back into a single row.
 *
 * @param {number[][]} Q - Query rows; the width of Q[0] defines dModel.
 * @param {number[][]} K - Key rows, same width as Q.
 * @param {number[][]} V - Value rows, one per key row.
 * @param {number} numHeads - Number of heads.
 * @returns {number[][]} One row per query position. If V rows are as wide as
 *     Q rows, each output row has width dModel. Empty Q yields [].
 */
function multiHeadAttention(Q, K, V, numHeads) {
    if (Q.length === 0) return [];

    const dModel = Q[0].length;
    // Integer column boundaries. When dModel is not divisible by numHeads
    // (e.g. 3 columns over 2 heads), the heads get unequal slices
    // ([0,1) and [1,3)), made explicit rather than left to slice() truncation.
    const boundary = head => Math.floor((head * dModel) / numHeads);
    const sliceColumns = (rows, head) =>
        rows.map(row => row.slice(boundary(head), boundary(head + 1)));

    const headOutputs = [];
    for (let head = 0; head < numHeads; head++) {
        headOutputs.push(selfAttention(
            sliceColumns(Q, head),
            sliceColumns(K, head),
            sliceColumns(V, head)
        ));
    }

    // Concatenate heads along the feature axis for each position, instead of
    // stacking the heads' rows on top of each other.
    return Q.map((_, position) => headOutputs.flatMap(out => out[position]));
}

/**
 * Sinusoidal positional encoding.
 *
 * An even column i holds sin(pos / 10000^(i/dModel)). An odd column i holds
 * cos(pos / 10000^((i-1)/dModel)).
 *
 * @param {number} seqLen - Number of positions.
 * @param {number} dModel - Width of each encoding row.
 * @returns {number[][]} seqLen rows of dModel numbers.
 */
function positionalEncoding(seqLen, dModel) {
    return Array.from({ length: seqLen }, (_, pos) =>
        Array.from({ length: dModel }, (_, i) =>
            i % 2 === 0 ?
                Math.sin(pos / Math.pow(10000, i / dModel)) :
                Math.cos(pos / Math.pow(10000, (i - 1) / dModel))
        )
    );
}

/**
 * Element-wise sum of embeddings and positional encodings.
 *
 * @param {number[][]} embeddings - One row per token.
 * @param {number[][]} positionalEncodings - Must cover every row/column of `embeddings`.
 * @returns {number[][]} A new matrix shaped like `embeddings`.
 */
function addPositionalEncoding(embeddings, positionalEncodings) {
    return embeddings.map((row, i) =>
        row.map((val, j) => val + positionalEncodings[i][j])
    );
}

/**
 * Two-layer ReLU feed-forward network: relu(x·W1 + b1)·W2 + b2.
 *
 * If weights are omitted, the defaults are an identity first layer, an
 * all-ones second layer and zero biases, sized to the widths actually
 * involved. For 2-wide input this is exactly the original fixed network.
 *
 * @param {number[][]} x - Input rows.
 * @param {number[][]} [W1] - First-layer weights (inputWidth × hiddenWidth).
 * @param {number[]} [b1] - First-layer bias, one per hidden unit.
 * @param {number[][]} [W2] - Second-layer weights (hiddenWidth × outputWidth).
 * @param {number[]} [b2] - Second-layer bias, one per output unit.
 * @returns {number[][]} One output row per input row. Inputs with no rows
 *     or zero-width rows map to empty rows.
 */
function feedForwardNetwork(x, W1, b1, W2, b2) {
    if (x.length === 0 || x[0].length === 0) return x.map(() => []);

    const matrix = (rows, cols, fill) =>
        Array.from({ length: rows }, (_, r) =>
            Array.from({ length: cols }, (_, c) => fill(r, c))
        );
    const zeros = n => Array.from({ length: n }, () => 0);

    // Size the defaults to the real input width instead of a fixed 2, so that
    // e.g. 3-wide attention output no longer overruns a 2×2 weight matrix.
    const inputWidth = x[0].length;
    const w1 = W1 === undefined
        ? matrix(inputWidth, inputWidth, (r, c) => (r === c ? 1 : 0))
        : W1;
    const hiddenWidth = w1[0].length;
    const bias1 = b1 === undefined ? zeros(hiddenWidth) : b1;
    const w2 = W2 === undefined ? matrix(hiddenWidth, hiddenWidth, () => 1) : W2;
    const bias2 = b2 === undefined ? zeros(w2[0].length) : b2;

    // Biases are indexed by output column j (one per unit), not by row i.
    const hidden = matMul(x, w1).map(row =>
        row.map((val, j) => Math.max(0, val + bias1[j]))
    );
    return matMul(hidden, w2).map(row =>
        row.map((val, j) => val + bias2[j])
    );
}

/**
 * Lower-cases text, removes the characters . , ! ? and splits on runs of
 * whitespace.
 *
 * Leading or trailing whitespace produces '' tokens (' a' -> ['', 'a']), and
 * the empty string produces [''].
 *
 * @param {string} text
 * @returns {string[]}
 */
function tokenize(text) {
    const lowered = text.toLowerCase();
    const withoutPunctuation = lowered.replace(/[.,!?]/g, '');
    return withoutPunctuation.split(/\s+/);
}

/**
 * Placeholder embeddings: each token gets an independent vector of
 * Math.random() values in [0, 1).
 *
 * Token identity is ignored, so repeated calls return different vectors.
 *
 * @param {string[]} tokens - Tokens to embed (only their count is used).
 * @param {number} [dModel=3] - Width of each embedding vector.
 * @returns {number[][]} tokens.length rows of dModel numbers.
 */
function embedTokens(tokens, dModel = 3) {
    return tokens.map(() => Array.from({ length: dModel }, () => Math.random()));
}

/**
 * Builds bigram and trigram successor tables from a text.
 *
 * bigramModel maps a word to the words seen right after it.
 * trigramModel maps "w1 w2" to the words seen right after that pair.
 * Successors appear once per occurrence, in corpus order.
 *
 * @param {string} corpus - Text to learn from (tokenized with tokenize()).
 * @returns {{bigramModel: Object<string, string[]>, trigramModel: Object<string, string[]>}}
 *     Null-prototype lookup tables.
 */
function buildNGramModels(corpus) {
    // Null-prototype objects make corpus words such as "constructor" or
    // "__proto__" ordinary keys. With a plain {} they resolve to inherited
    // Object.prototype members, and `.push` on them throws.
    const bigramModel = Object.create(null);
    const trigramModel = Object.create(null);
    // tokenize() yields '' for leading/trailing whitespace; don't learn it.
    const words = tokenize(corpus).filter(word => word.length > 0);

    const append = (model, key, next) => {
        if (!model[key]) model[key] = [];
        model[key].push(next);
    };

    for (let i = 0; i + 1 < words.length; i++) {
        append(bigramModel, words[i], words[i + 1]);
        if (i + 2 < words.length) {
            append(trigramModel, `${words[i]} ${words[i + 1]}`, words[i + 2]);
        }
    }

    return { bigramModel, trigramModel };
}

/**
 * Picks a random next word for `text` from the n-gram models.
 *
 * It tries the trigram table keyed by the last two words first, then falls
 * back to the bigram table keyed by the last word.
 *
 * @param {string} text - Context; only its last one or two tokens matter.
 * @param {{bigramModel: object, trigramModel: object}} models - As built by buildNGramModels().
 * @returns {string} A successor word, or '' when no table has an entry.
 */
function predictNextWord(text, models) {
    const { bigramModel, trigramModel } = models;
    // tokenize() keeps '' for leading/trailing whitespace ("hi " -> ["hi", ""]),
    // which would otherwise become the lookup key and hide real matches.
    const words = tokenize(text).filter(word => word.length > 0);
    if (words.length === 0) return '';

    // Own-property lookup, so words like "constructor" don't resolve to
    // inherited Object.prototype members of plain-object models.
    const lookup = (model, key) =>
        Object.prototype.hasOwnProperty.call(model, key) ? model[key] : undefined;
    const pickRandom = options => options[Math.floor(Math.random() * options.length)];

    const lastWord = words[words.length - 1];
    if (words.length >= 2) {
        const trigramOptions = lookup(trigramModel, `${words[words.length - 2]} ${lastWord}`);
        if (trigramOptions) return pickRandom(trigramOptions);
    }

    const bigramOptions = lookup(bigramModel, lastWord);
    if (bigramOptions) return pickRandom(bigramOptions);

    return '';
}

/**
 * Runs a toy transformer pass over `text`, then returns the n-gram prediction.
 *
 * NOTE(review): the attention / feed-forward result is computed but never
 * used. The returned word comes solely from predictNextWord(), and it is
 * also logged to the console.
 *
 * @param {string} text - Prompt text.
 * @param {{bigramModel: object, trigramModel: object}} nGramModels
 * @returns {string} The n-gram prediction ('' if none).
 */
function predictNextWordWithAttention(text, nGramModels) {
    const dModel = 3;
    const numHeads = 2;
    const tokens = tokenize(text);

    const encoded = addPositionalEncoding(
        embedTokens(tokens),
        positionalEncoding(tokens.length, dModel)
    );
    const attended = multiHeadAttention(encoded, encoded, encoded, numHeads);
    feedForwardNetwork(attended);

    const prediction = predictNextWord(text, nGramModels);
    console.log('N-Gram Prediction:', prediction);
    return prediction;
}

/**
 * Light clean-up of generated text for display.
 *
 * The steps are:
 * - upper-case the first character;
 * - drop trailing whitespace;
 * - end the text with '.' unless it already ends in . ! or ?;
 * - upper-case the first word character after each . ! or ? followed by
 *   whitespace.
 *
 * @param {string} text
 * @returns {string} The corrected text ('' stays '').
 */
function correctText(text) {
    let result = (text.charAt(0).toUpperCase() + text.slice(1)).trimEnd();
    if (result.length > 0 && !/[.!?]$/.test(result)) {
        result += '.';
    }
    // Keep the matched boundary (start of text, or punctuation plus
    // whitespace) and only upper-case the following character. Returning just
    // the letter would delete the ". " between sentences.
    return result.replace(/(^|[.!?]\s+)(\w)/g, (_, boundary, ch) => boundary + ch.toUpperCase());
}

/**
 * Decides whether a generated sentence should stop growing.
 *
 * The sentence counts as complete when all of the following hold:
 * - it contains at least one word;
 * - it either ends in . ! or ?, or is longer than one character;
 * - after the normalisation tokenize() applies (lower-casing, dropping
 *   . , ! ?, collapsing whitespace), it equals one of the corpus's
 *   '.'-separated sentences.
 *
 * @param {string} sentence - Candidate text.
 * @param {string} corpus - Raw reference text.
 * @returns {boolean}
 */
function isSentenceComplete(sentence, corpus) {
    const trimmedSentence = sentence.trim();
    if (trimmedSentence.length === 0) return false;

    const hasWords = /\b\w+\b/.test(trimmedSentence);
    const endsWithPunctuation = /[.!?]$/.test(trimmedSentence);
    const isReasonablyLong = trimmedSentence.length > 1;
    if (!hasWords || !(endsWithPunctuation || isReasonablyLong)) return false;

    // Generated text is assembled from tokenize() output (lower-case, no
    // punctuation), so compare both sides in that normalised form. Pieces
    // from split('.') never contain '.', so equality is the only test needed.
    const normalize = text => tokenize(text).filter(word => word.length > 0).join(' ');
    const target = normalize(trimmedSentence);
    return corpus.split('.').some(corpSentence => normalize(corpSentence) === target);
}

/**
 * Grows `startText` one predicted word at a time.
 *
 * Generation stops when any of these happens:
 * - the n-gram models have no continuation;
 * - the sentence reaches `maxLength` words (prompt words included);
 * - isSentenceComplete() accepts it against `referenceCorpus`.
 *
 * @param {string} startText - Prompt; surrounding whitespace is ignored.
 * @param {object} nGramModels - Models as returned by buildNGramModels().
 * @param {number} [maxLength=20] - Maximum number of words in the result.
 * @param {string} [referenceCorpus=corpus] - Text for the completeness check;
 *     defaults to the imported corpus.
 * @returns {string} Prompt words followed by generated words, single-space separated.
 */
function generateSentence(startText, nGramModels, maxLength = 20, referenceCorpus = corpus) {
    // Split on whitespace runs, so repeated spaces in the prompt neither
    // inflate the word count nor create empty "words". An empty prompt gives
    // [], which avoids a leading space in the output.
    const words = startText.trim().split(/\s+/).filter(word => word.length > 0);
    // The first prediction sees the whole prompt; later ones see the last two words.
    let context = startText.trim();

    while (words.length < maxLength) {
        const nextWord = predictNextWord(context, nGramModels);
        if (!nextWord) break;

        words.push(nextWord);
        context = words.slice(-2).join(' ');

        if (isSentenceComplete(words.join(' '), referenceCorpus)) {
            break;
        }
    }

    return words.join(' ');
}

/**
 * Cleans the raw corpus and builds the bigram/trigram models from it.
 *
 * Line breaks become spaces and . , ! ? are removed before training. A start
 * message and a completion message are logged.
 *
 * @param {string} corpus - Raw training text.
 * @returns {*} Whatever buildNGramModels() returns for the cleaned text.
 */
function TrainModel(corpus) {
    console.log(`Training for ${modelname} has begun.`);
    // Strip punctuation before trimming. Trimming first leaves a dangling
    // space when the text starts or ends with punctuation ("end ." -> "end "),
    // and tokenize() turns that into an empty token.
    const cleanedCorpus = corpus
        .replace(/[\r\n]+/g, ' ')
        .replace(/[.,!?]/g, '')
        .trim();
    const nGramModels = buildNGramModels(cleanedCorpus);
    console.log("Training complete!");
    return nGramModels;
}

/**
 * Continues `inputText2` using the module-level trained models, capped at
 * OutputLength words.
 *
 * Unlike the demo at the bottom of the file, the result is not passed
 * through correctText().
 *
 * @param {string} inputText2 - Prompt text.
 * @returns {string} The generated sentence.
 */
function GetAIresponse(inputText2) {
    const models = nGramModels;
    const wordLimit = OutputLength;
    return generateSentence(inputText2, models, wordLimit);
}

// Train once when the module loads; GetAIresponse() reads this binding.
const nGramModels = TrainModel(corpus);

// Demo run: continue a fixed prompt, tidy the result, and print it.
const inputText = 'hi';
const rawSentence = generateSentence(inputText, nGramModels, OutputLength);
const startText = correctText(rawSentence);

console.log('Generated Sentence:', startText);
