# LDA主题挖掘并通过一致性分数和困惑度进行验证
# (LDA topic mining, validated with coherence score and perplexity)

def tokenize_titles(titles, stop_words, tokenize):
    """Split each title into lowercase tokens with stopwords removed.

    Args:
        titles: iterable of title strings.
        stop_words: collection of lowercase words to drop.
        tokenize: callable mapping one string to a list of tokens (the
            script passes the ``tokenize`` method of an
            ``nltk.RegexpTokenizer``).

    Returns:
        A list with one token list per title, in input order.

    Tokens are lowercased both for the stopword check and in the output, so
    e.g. "Economy" and "economy" become the same dictionary entry instead of
    two unrelated words.
    """
    tokenized = []
    for title in titles:
        lowered = (token.lower() for token in tokenize(title))
        tokenized.append([token for token in lowered if token not in stop_words])
    return tokenized


def perplexity_from_bound(per_word_bound):
    """Convert gensim's per-word likelihood bound into a perplexity estimate.

    ``LdaModel.log_perplexity`` returns a per-word bound (normally negative);
    gensim itself reports the corresponding perplexity estimate as
    ``2 ** (-bound)``. Applying ``math.exp`` to the bound instead yields a
    number below 1, which is not a perplexity.
    """
    return 2.0 ** (-per_word_bound)


def main():
    """Fit LDA models over a range of topic counts and plot coherence/perplexity.

    Reads titles from the '标题' column of "Laos news overall.xlsx", drops
    English stopwords, then for each of ``num_runs`` runs fits one gensim
    ``LdaModel`` per topic count in [start_topics, end_topics] and records its
    c_v coherence and perplexity. Each metric is drawn on its own subplot,
    one curve per run.

    Raises:
        ValueError: if no tokens remain after stopword removal.
    """
    from gensim.models import CoherenceModel
    from gensim.corpora.dictionary import Dictionary
    from gensim.models.ldamodel import LdaModel
    import pandas as pd
    import nltk
    import matplotlib.pyplot as plt

    nltk.download('stopwords')

    data = pd.read_excel("Laos news overall.xlsx")
    # Empty cells come back as NaN (a float), which the tokenizer cannot
    # handle; any non-string titles are coerced to text.
    titles = data['标题'].dropna().astype(str).tolist()

    stop_words = set(nltk.corpus.stopwords.words('english'))
    tokenizer = nltk.RegexpTokenizer(r'\w+')
    tokenized_titles = tokenize_titles(titles, stop_words, tokenizer.tokenize)

    # The vocabulary and bag-of-words corpus depend only on the texts, so
    # build them once rather than once per (run, topic count).
    dictionary = Dictionary(tokenized_titles)
    if len(dictionary) == 0:
        raise ValueError("No tokens left after stopword removal; nothing to model.")
    corpus = [dictionary.doc2bow(title) for title in tokenized_titles]

    start_topics = 1
    end_topics = 20
    num_runs = 10  # Number of times to run the LDA and plot
    topic_range = range(start_topics, end_topics + 1)

    # Coherence (c_v) and perplexity live on very different scales, so each
    # gets its own y-axis; sharing x keeps the topic counts aligned.
    fig, (ax_coherence, ax_perplexity) = plt.subplots(2, 1, sharex=True, figsize=(10, 8))

    for run in range(1, num_runs + 1):
        coherence_scores = []
        perplexity_scores = []

        for num_topics in topic_range:
            # No random_state is passed, so repeated runs presumably differ;
            # plotting several runs shows that variance.
            lda = LdaModel(corpus, num_topics=num_topics, id2word=dictionary)

            coherence_model = CoherenceModel(model=lda, texts=tokenized_titles, dictionary=dictionary, coherence='c_v')
            coherence_scores.append(coherence_model.get_coherence())

            # NOTE: evaluated on the training corpus, not a held-out set.
            perplexity_scores.append(perplexity_from_bound(lda.log_perplexity(corpus)))

        ax_coherence.plot(topic_range, coherence_scores, marker='o', label='Coherence Score (Run {})'.format(run))
        ax_perplexity.plot(topic_range, perplexity_scores, marker='x', label='Perplexity (Run {})'.format(run))

    ax_coherence.set_ylabel("Coherence Score (c_v)")
    ax_perplexity.set_ylabel("Perplexity")
    ax_perplexity.set_xlabel("Number of Topics")
    ax_perplexity.set_xticks(list(topic_range))
    for ax in (ax_coherence, ax_perplexity):
        ax.grid(True)
        ax.legend(fontsize='small')
    fig.suptitle("Coherence Score and Perplexity vs. Number of Topics")
    fig.tight_layout()
    plt.show()


if __name__ == '__main__':
    main()

# 你可能感兴趣的:(python)  (leftover page tag: "You may also be interested in: (python)")