模块1:数据抓取
首先需要在本地安装 MySQL,然后修改 news_main.py 中相应的 MySQL 连接参数:
connect = pymysql.connect(host='localhost',  # local database server
                          user='root',        # account
                          password='123456',  # password
                          db='news_collect',  # database name
                          # utf8mb4 matches the tables' DEFAULT CHARSET=utf8mb4; depending on
                          # the PyMySQL/MySQL version, 'utf8' may mean 3-byte utf8mb3.
                          charset='utf8mb4')
当然,如果不使用 MySQL,也可以修改相应代码,将抓取后的数据存入 xls 表格中,这里不做介绍。
news_main.py文件完整内容(数据抓取)
import pymysql
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.common.by import By
import hashlib
# MySQL connection: server, account, password, database name.
# utf8mb4 matches the DEFAULT CHARSET=utf8mb4 of the tables created below; depending
# on the PyMySQL/MySQL version, 'utf8' may mean the 3-byte utf8mb3 charset, which
# cannot carry 4-byte characters (e.g. emoji).
connect = pymysql.connect(host='localhost',  # local database
                          user='root',
                          password='123456',
                          db='news_collect',
                          charset='utf8mb4')
cur = connect.cursor()
# One Chrome session is shared by the category loader, visit_detail() and the crawl loop.
driver = webdriver.Chrome()
# Load the category ids from the site's menu: for every menu link the
# second-to-last '/'-separated piece of its href (the last path segment when the
# href ends with '/') becomes a category id, and later a MySQL table name.
industrys = []
try:
    driver.get('http://www.chinabgao.com/')
    WebDriverWait(driver, 100).until(
        EC.presence_of_element_located((By.ID, "cateitems"))
    )
    cates = driver.find_elements(by=By.XPATH, value="//li[contains(@class, 'item')]/h3/a")
    for cate in cates:
        cate_href = cate.get_attribute("href")
        if not cate_href:
            # An anchor without href would otherwise abort loading every category.
            continue
        industrys.append(cate_href.split("/")[-2])
except Exception as exc:
    # Best effort: report why, and keep whatever categories were collected.
    # (Unlike a bare except, Ctrl+C still stops the script.)
    print("栏目加载失败", exc)
# Create one table per category. A table that already exists (MySQL error 1050,
# ER_TABLE_EXISTS_ERROR) is reported as such; any other database error is reported
# separately instead of being mislabelled as "already exists".
for table in industrys:
    # Table names cannot be bound parameters: backtick-quote, doubling any backtick.
    create_table = (
        "CREATE TABLE `%s` ("
        " `id` bigint(20) NOT NULL AUTO_INCREMENT,"
        " `title` varchar(255) DEFAULT NULL,"
        " `content` varchar(255) DEFAULT NULL,"
        " `hash` varchar(255) DEFAULT NULL,"
        " PRIMARY KEY (`id`),"
        " UNIQUE KEY `d` (`hash`) USING HASH,"
        " `date1` date DEFAULT NULL,"
        " `date2` timestamp NULL DEFAULT NULL"
        ") ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4;"
    ) % (table.replace("`", "``"))
    try:
        cur.execute(create_table)
    except pymysql.err.MySQLError as exc:
        if exc.args and exc.args[0] == 1050:
            print("数据库table:%s已存在" % (table))
        else:
            print("数据库table:%s创建失败:" % (table), exc)
connect.commit()
def visit_detail(timeout=100):
    """Scrape every article linked from the list page the driver is on.

    Each ``div.listtitle > a`` link is opened, its publish time
    (``span.pubTime``) and description (``div.arcdec``) are read, then the
    driver returns to the list page and waits for ``listcon`` to reappear.

    Args:
        timeout: seconds to wait for each detail page and for the list page.

    Returns:
        One ``(date, pub_time, description)`` tuple per link, in link order;
        ``date`` is the first 10 characters of the publish-time text. A detail
        page that fails to load or lacks those elements yields ``("", "", "")``.
    """
    cur_url = driver.current_url
    views = driver.find_elements(by=By.XPATH, value="//div[contains(@class, 'listtitle')]/a")
    # Read all hrefs up front: the WebElements go stale once the driver navigates away.
    link_l = [view.get_attribute("href") for view in views]
    ret = []
    for view_href in link_l:
        try:
            driver.get(view_href)
            WebDriverWait(driver, timeout).until(
                EC.presence_of_element_located((By.CLASS_NAME, "arctitle"))
            )
            pub_time = driver.find_element(by=By.XPATH, value="//span[contains(@class, 'pubTime')]")
            print(pub_time.text)
            arc_dec = driver.find_element(by=By.XPATH, value="//div[contains(@class, 'arcdec')]")
            print(arc_dec.text)
            ret.append((pub_time.text[:10], pub_time.text, arc_dec.text))
        except Exception:
            # Best effort per article (navigation error, timeout, missing element).
            # The crawl loop treats an empty date as "stop paging this category".
            # Ctrl+C is not swallowed, unlike with a bare except.
            ret.append(("", "", ""))
        driver.get(cur_url)
        WebDriverWait(driver, timeout).until(
            EC.presence_of_element_located((By.CLASS_NAME, "listcon"))
        )
    return ret
# Crawl every category page by page and store each article in its category table.
# A category is left when it has no next page, when a detail page fails, or when
# every insert on a page fails (e.g. all titles already stored: `hash` is UNIQUE),
# so a re-run presumably stops at the first already-stored page.
from selenium.common.exceptions import NoSuchElementException


def _open_category(idx):
    """Open the list page of ``industrys[idx]``.

    Returns False, without navigating, once *idx* is past the last category;
    this ends the crawl loop instead of raising IndexError.
    """
    if idx >= len(industrys):
        return False
    print("当前插入栏目:", 'http://www.chinabgao.com/info/%s' % (industrys[idx]))
    driver.get('http://www.chinabgao.com/info/%s' % (industrys[idx]))
    return True


def _insert_sql(table):
    """Return a parameterized INSERT statement for category table *table*.

    The table name cannot be a bound parameter, so it is backtick-quoted here;
    '%' is doubled because PyMySQL %-interpolates the bound values into the query.
    """
    quoted = "`%s`" % table.replace("`", "``").replace("%", "%%")
    return ("insert into %s (title, content, hash, date1, date2) "
            "values (%%s, %%s, %%s, %%s, %%s);" % quoted)


i = 0
try:
    active = _open_category(i)
    while active:
        WebDriverWait(driver, 100).until(
            EC.presence_of_element_located((By.CLASS_NAME, "listcon"))
        )
        list_title = [content.text for content in driver.find_elements(by=By.CLASS_NAME, value='listtitle')]
        # Link to the next list page ("下一页" = "next page"), if there is one.
        next_href = ""
        try:
            next_link = driver.find_element(by=By.XPATH, value="//span[contains(@class, 'pagebox_next')]/a[text()='下一页']")
            next_href = next_link.get_attribute("href") or ""
        except NoSuchElementException:
            pass  # last page of this category
        detail = visit_detail()
        sql = _insert_sql(industrys[i])
        err_cnt = 0
        for j, title in enumerate(list_title):
            try:
                if detail[j][0] == "":
                    # A detail page failed: stop paging this category.
                    next_href = ""
                    break
                # NOTE(review): title/content are varchar(255); longer text presumably
                # fails in strict SQL mode and is counted as an error below.
                cur.execute(sql, (title, detail[j][2],
                                  hashlib.md5(title.encode(encoding='UTF-8')).hexdigest(),
                                  detail[j][0], detail[j][1]))
                connect.commit()
            except (IndexError, pymysql.err.MySQLError):
                # Missing detail row, duplicate hash, invalid date, ...
                err_cnt += 1
        # Next category when every insert failed (also true for an empty page)
        # or there is no further page; otherwise continue with the next page.
        if err_cnt == len(list_title) or not next_href:
            i += 1
            active = _open_category(i)
        else:
            driver.get(next_href)
finally:
    driver.quit()
# Close the database: the cursor first, then the connection that owns it.
for _db_handle in (cur, connect):
    _db_handle.close()
模块2:数据分析
import os
import sys
from unittest import result
import pymysql
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.common.by import By
import hashlib
# Discover the crawled category tables of the news_collect schema from
# mysql.innodb_table_stats (row[1] is its table_name column).
industrys = []
connect = pymysql.connect(host='localhost',  # local database
                          user='root',
                          password='123456',
                          db='mysql',
                          charset='utf8')  # server, account, password, database name
try:
    with connect.cursor() as admin:
        admin.execute("select * from `innodb_table_stats` where database_name = 'news_collect'")
        industrys = [table[1] for table in admin.fetchall()]
finally:
    # This connection is only needed for the lookup; close it before the name
    # `connect` is rebound to the news_collect database.
    connect.close()
# Connection to the crawled data: server, account, password, database name.
# utf8mb4 matches the DEFAULT CHARSET=utf8mb4 of the crawler's tables; depending on
# the PyMySQL/MySQL version, 'utf8' may mean 3-byte utf8mb3, which cannot
# represent 4-byte characters.
connect = pymysql.connect(host='localhost',  # local database
                          user='root',
                          password='123456',
                          db='news_collect',
                          charset='utf8mb4')
cur = connect.cursor()
print(industrys)
def _trailing_30_day_counts(rows):
    """Count, for every distinct date1, the rows dated within the 30 days up to it.

    Args:
        rows: query rows sorted ascending by ``date1`` (column index 4).

    Returns:
        ``(labels, counts)``: the distinct dates as strings in ascending order and,
        for each date ``d``, the number of rows with ``d - 30 days <= date1 <= d``.
    """
    import bisect
    import datetime

    dates = [row[4] for row in rows]
    distinct = sorted(set(dates))
    window = datetime.timedelta(days=30)
    # `dates` is sorted, so each window is a contiguous slice found by binary search.
    counts = [bisect.bisect_right(dates, day) - bisect.bisect_left(dates, day - window)
              for day in distinct]
    return [str(day) for day in distinct], counts


def _export_texts(industry, rows, industry_dir):
    """Write each row as ``<idx>.txt`` (title line, then content) into *industry_dir*."""
    os.makedirs(industry_dir, exist_ok=True)
    for idx, row in enumerate(rows):
        print(industry, row)
        with open(industry_dir + "/" + str(idx) + ".txt", 'w', encoding='utf-8') as f:
            # title/content columns are nullable (DEFAULT NULL): write "" for NULL.
            f.write((row[1] or "") + "\n")
            f.write(row[2] or "")


def _save_plot(industry, labels, counts, pic_path):
    """Plot *counts* against *labels*, titled *industry*, and save it to *pic_path*."""
    os.makedirs(os.path.dirname(pic_path), exist_ok=True)
    fig = plt.figure()
    try:
        fig.suptitle(industry)
        plt.plot(labels, counts)
        plt.xticks(rotation=90)
        fig.savefig(pic_path)
    finally:
        # One figure per category; unclosed pyplot figures accumulate in memory.
        plt.close(fig)


# For every category: sort its rows by date1, compute the trailing 30-day counts,
# export the rows as a text corpus (analyse/model/chinabao/<category>/) and save
# the count curve to analyse/pic/<category>.jpg.
_script_dir = os.path.dirname(os.path.abspath(__file__))
china_bao = _script_dir + "/analyse/model" + "/chinabao"
for industry in industrys:
    cur.execute("select * from `%s`" % (industry))
    result = sorted(cur.fetchall(), key=lambda x: x[4])
    print(result)
    l_date_str, l_cnt = _trailing_30_day_counts(result)
    _export_texts(industry, result, china_bao + "/" + industry)
    print(l_date_str)
    print(l_cnt)
    _save_plot(industry, l_date_str, l_cnt, _script_dir + "/analyse/pic/" + industry + ".jpg")
    # plt.show()
cur.close()
connect.close()
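运行数据分析脚本后,会在 analyse/pic 文件夹下为每个栏目生成一张政策数量的 30 天滚动计数图(每个日期统计该日及此前 30 天内的条数),同时把每条记录导出到 analyse/model/chinabao/<栏目>/ 目录下,作为模块3文本分类的训练语料。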
模块3:自然语言处理之文本分类
# -*- coding:utf-8 -*-
# Author:hankcs
# Date: 2018-05-23 17:26
import os
import sys
# Extend the module search path, presumably so that `pyhanlp` and `test_utility`
# (imported next) resolve from a local pyhanlp checkout; this must run before
# those imports.
# NOTE(review): hard-coded Windows path — adjust to where pyhanlp lives locally.
sys.path.append('F:/test/nlp/pyhanlp')
# Debug output: show the resulting module search path.
print(sys.path)
from pyhanlp import SafeJClass
from test_utility import ensure_data
# HanLP Java classes, obtained through pyhanlp's SafeJClass.
NaiveBayesClassifier = SafeJClass(
    'com.hankcs.hanlp.classification.classifiers.NaiveBayesClassifier'
)
IOUtil = SafeJClass('com.hankcs.hanlp.corpus.io.IOUtil')

# Training corpus: analyse/model/chinabao next to this script — presumably the
# per-category text folders written by the data-analysis step.
# The original HanLP demo downloaded the Sogou mini corpus instead:
# sogou_corpus_path = ensure_data('搜狗文本分类语料库迷你版',
#     'http://file.hankcs.com/corpus/sogou-text-classification-corpus-mini.zip')
sogou_corpus_path = "".join(
    (os.path.dirname(os.path.abspath(__file__)), "/analyse/model/chinabao")
)
def train_or_load_classifier(corpus_path=None, force_retrain=False):
    """Return a NaiveBayesClassifier, training and caching it when needed.

    The trained model is cached as ``<corpus_path>.ser``. When that file exists
    it is loaded instead of retraining, so a regenerated corpus is only picked up
    with ``force_retrain=True`` (or after deleting the .ser file).

    Args:
        corpus_path: corpus folder passed to ``NaiveBayesClassifier.train``;
            defaults to the module-level ``sogou_corpus_path``.
        force_retrain: ignore any cached model and retrain from the corpus.

    Returns:
        A NaiveBayesClassifier built from the loaded or freshly trained model.
    """
    if corpus_path is None:
        corpus_path = sogou_corpus_path
    model_path = corpus_path + '.ser'
    print(model_path)
    if not force_retrain and os.path.isfile(model_path):
        return NaiveBayesClassifier(IOUtil.readObjectFrom(model_path))
    classifier = NaiveBayesClassifier()
    classifier.train(corpus_path)
    model = classifier.getModel()
    IOUtil.saveObjectTo(model, model_path)
    return NaiveBayesClassifier(model)
def predict(classifier, text):
    """Print the category *classifier* assigns to *text*.

    Line format: the text right-aligned in a 16-character field inside 《》,
    then the label returned by ``classifier.classify`` inside 【】.
    """
    label = classifier.classify(text)
    message = "《%16s》\t属于分类\t【%s】" % (text, label)
    print(message)
    # For the probability distribution over the categories, use the predict interface:
    # print("《%16s》\t属于分类\t【%s】" % (text, classifier.predict(text)))
if __name__ == '__main__':
    # Train (or load the cached model) once, then classify a few sample headlines.
    classifier = train_or_load_classifier()
    samples = (
        "手机市场",
        "养鸡行业利润暴增",
        "北京家具质检结果一半不合格",
        "业内看好草甘膦供需格局",
        "据宇博智业研究中心了解,借助肥业APP,企业可以抓住移动电商发展的机会,成功的实现企业转型。",
    )
    for sample in samples:
        predict(classifier, sample)
输出分类结果: