import requests
import os
import signal
import argparse
import json
import pymongo
from queue import Queue
from threading import Thread
from urllib.parse import urlparse

# S2_API_KEY = os.getenv('S2_API_KEY')
QUERY_FIELDS1 = 'paperId,corpusId,title,authors,year,url,tldr,venue,externalIds,fieldsOfStudy,s2FieldsOfStudy,abstract,citationCount,referenceCount,publicationTypes,influentialCitationCount,publicationDate,journal'
QUERY_FIELDS2 = 'paperId,corpusId,title,authors,year,url,venue,externalIds,fieldsOfStudy,s2FieldsOfStudy,abstract,citationCount,referenceCount,publicationTypes,influentialCitationCount,publicationDate,journal'
QUERY_FIELDS3 = 'paperId,corpusId,title,authors'

# Read database and crawler settings from the config file
with open("config.json", "r") as f:
    config = json.load(f)
db_url = config["db_url"]
db_name = config["db_name"]
db_collection = config["db_collection"]
NUM_THREADS = config["num_threads"]
TASK_QUEUE_LEN = config["task_queue_len"]
S2_API_KEY = config["s2_api_key"]

# Flag used to request a graceful shutdown (set by the SIGINT handler)
quit_flag = False

# Connect to the database and get the papers collections
client = pymongo.MongoClient(db_url)
db = client[db_name]
papers = db[db_collection]
papers_data = db['{}_data'.format(db_collection)]


def read_file(filename):
    data_list = []
    with open(filename, 'r') as f:
        for line in f:
            line_dict = json.loads(line)
            data_list.append(line_dict)
            # Each parsed dict can be processed here, e.g.:
            # print(line_dict['key'])
    return data_list


def add_paper(file_path):
    papers.create_index("corpusid", unique=True)
    # Read the paper file and store its records in the database
    # data_list = read_file(file_path)
    # Insert the data in batches
    inserted_ids = 0
    try:
        sub_list = []
        with open(file_path, 'r') as f:
            for line in f:
                line_dict = json.loads(line)
                sub_list.append(line_dict)
                if len(sub_list) == 2000:
                    result = papers.insert_many(sub_list, ordered=False)
                    inserted_ids += len(result.inserted_ids)
                    sub_list = []
                    print('-------process', inserted_ids, '/', '7318795')
            if sub_list:
                result = papers.insert_many(sub_list, ordered=False)
                inserted_ids += len(result.inserted_ids)
                sub_list = []
    except pymongo.errors.BulkWriteError as e:
        inserted_ids = e.details['nInserted']
    finally:
        # Report the insert results
        print("Total records: {0}, inserted this run: {1}, now in collection: {2}".format(
            7318795, inserted_ids, papers.count_documents({})))


def crawl_data():
    # papers_data.create_index("corpusId", unique=True)
    # Create the task queue and worker threads
    q = Queue(TASK_QUEUE_LEN)
    # num_threads = 4
    threads = []
    for i in range(NUM_THREADS):
        t = Thread(target=worker, args=(q,))
        t.daemon = True
        t.start()
        print("starting worker: {}".format(t.native_id))
        threads.append(t)

    # Read URLs from the database and add them to the task queue
    while True:
        try:
            for data in papers.find({'$or': [{'consumed': {'$exists': False}},
                                             {'consumed': False}]}):
                if quit_flag:
                    break
                if 'consumed' in data and data['consumed']:
                    print(data['corpusid'], "already inserted")
                    continue
                print('add {} to the task queue'.format(data['corpusid']))
                q.put((data['url'], data['corpusid']))
            break
        except Exception as e:
            print('crawl_data error', e)
            continue

    # print("Waiting for the task queue to complete...")
    q.join()
    print("The task queue has been completed!")

    # Stop the worker threads
    for i in range(NUM_THREADS):
        q.put(None)
    for t in threads:
        print("stopping worker: {}".format(t.native_id))
        t.join()


def mark_data_as_consumed(corpus_id):
    result = papers.update_one({'corpusid': corpus_id},
                               {'$set': {'consumed': True}})


def worker(q):
    while True:
        item = q.get()
        if item is None:  # sentinel value: shut this worker down
            break
        url = urlparse(item[0]).path
        paper_id = url.split('/')[-1]
        corpus_id = item[1]
        print('crawling {} data: {}'.format(corpus_id, url))
        try:
            data = fetch_data(paper_id)
            if data is not None:
                # papers_data.insert_one(data)
                filter = {'corpusId': corpus_id}
                update = {'$set': data}
                result = papers_data.update_one(filter, update, upsert=True)
                mark_data_as_consumed(corpus_id)
                print(result.upserted_id, "inserted successfully")
        except Exception as error:
            # handle the exception
            print("An exception occurred:", error)
        finally:
            q.task_done()


def get_paper(paper_id):
    rsp = requests.get(f'https://api.semanticscholar.org/graph/v1/paper/{paper_id}',
                       headers={'x-api-key': S2_API_KEY},
                       params={'fields': QUERY_FIELDS1})
    rsp.raise_for_status()
    return rsp.json()


def get_citations(paper_id):
    edges = get_citation_edges(url=f'https://api.semanticscholar.org/graph/v1/paper/{paper_id}/citations',
                               headers={'x-api-key': S2_API_KEY},
                               params={'fields': QUERY_FIELDS2})
    return list(edge['citingPaper'] for edge in edges)


def get_references(paper_id):
    edges = get_citation_edges(url=f'https://api.semanticscholar.org/graph/v1/paper/{paper_id}/references',
                               headers={'x-api-key': S2_API_KEY},
                               params={'fields': QUERY_FIELDS2})
    return list(edge['citedPaper'] for edge in edges)


# This endpoint is protected by a human-verification (CAPTCHA) check
def get_related_papers(paper_id):
    rsp = requests.get(url=f'https://www.semanticscholar.org/api/1/paper/{paper_id}/related-papers?limit=10&recommenderType=relatedPapers',
                       headers={'x-api-key': S2_API_KEY},
                       params={'fields': QUERY_FIELDS3})
    rsp.raise_for_status()
    return rsp.json()['papers']


def get_recommended_papers(paper_id):
    rsp = requests.get(url=f'https://api.semanticscholar.org/recommendations/v1/papers/forpaper/{paper_id}',
                       headers={'x-api-key': S2_API_KEY},
                       params={'fields': QUERY_FIELDS2})
    rsp.raise_for_status()
    return rsp.json()['recommendedPapers']


def get_citation_edges(**req_kwargs):
    """This helps with API endpoints that involve paging."""
    page_size = 1000
    offset = 0
    while True:
        req_kwargs.setdefault('params', dict())
        req_kwargs['params']['limit'] = page_size
        req_kwargs['params']['offset'] = offset
        rsp = requests.get(**req_kwargs)
        rsp.raise_for_status()
        page = rsp.json()["data"]
        for element in page:
            yield element
        if len(page) < page_size:
            break  # no more pages
        offset += page_size


def fetch_data(paper_id):
    print("fetching data:", paper_id)
    data = get_paper(paper_id)
    # print(paper)
    data['citations'] = get_citations(paper_id)
    data['references'] = get_references(paper_id)
    data['recommendedPapers'] = get_recommended_papers(paper_id)
    print('>>> fetch data OK, citations: {0}, references: {1}, recommendedPapers: {2}'.format(
        len(data.get('citations', [])),
        len(data.get('references', [])),
        len(data.get('recommendedPapers', []))
    ))
    return data if isinstance(data, dict) else None


def onSigInt(signo, frame):
    global quit_flag
    quit_flag = True
    print('Ctrl C: Waiting for the process to exit...')


if __name__ == "__main__":
    # Install the SIGINT handler so the main process can exit cleanly
    signal.signal(signal.SIGINT, onSigInt)
    parser = argparse.ArgumentParser(description="Crawl data from URLs")
    parser.add_argument(
        "command", choices=["add_paper", "crawl_data"], help="Command to execute"
    )
    parser.add_argument("--path", help="Path to add to papers")
    args = parser.parse_args()
    if args.command == "add_paper":
        add_paper(args.path)
    elif args.command == "crawl_data":
        crawl_data()
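
# ---------------------------------------------------------------------------
# Example usage (a sketch: the script filename and all values below are
# illustrative placeholders, not taken from the original project):
#
#   python crawler.py add_paper --path papers.jsonl   # bulk-load a JSON-lines paper dump
#   python crawler.py crawl_data                      # fetch details for unconsumed papers
#
# The script reads a config.json at startup with exactly the keys accessed
# above. A minimal example, assuming a local MongoDB instance:
#
#   {
#       "db_url": "mongodb://localhost:27017/",
#       "db_name": "s2",
#       "db_collection": "papers",
#       "num_threads": 4,
#       "task_queue_len": 100,
#       "s2_api_key": "<your Semantic Scholar API key>"
#   }
# ---------------------------------------------------------------------------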