123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- import requests
- import os
- import signal
- import argparse
- import json
- import pymongo
- from queue import Queue
- from threading import Thread
- from urllib.parse import urlparse
# Fields requested from the Semantic Scholar APIs: QUERY_FIELDS1 for the paper
# itself, QUERY_FIELDS2 (the same list without 'tldr') for citing, cited and
# recommended papers, QUERY_FIELDS3 for related papers.
QUERY_FIELDS1 = ','.join((
    'paperId', 'corpusId', 'title', 'authors', 'year', 'url', 'tldr',
    'venue', 'externalIds', 'fieldsOfStudy', 's2FieldsOfStudy', 'abstract',
    'citationCount', 'referenceCount', 'publicationTypes',
    'influentialCitationCount', 'publicationDate', 'journal',
))
QUERY_FIELDS2 = QUERY_FIELDS1.replace('tldr,', '')
QUERY_FIELDS3 = 'paperId,corpusId,title,authors'

# Database and crawler settings come from config.json in the working directory.
with open("config.json", "r") as config_file:
    config = json.load(config_file)
db_url, db_name, db_collection = (
    config["db_url"], config["db_name"], config["db_collection"])
NUM_THREADS = config["num_threads"]
TASK_QUEUE_LEN = config["task_queue_len"]
S2_API_KEY = config["s2_api_key"]

# Set by the SIGINT handler to ask crawl_data to stop queueing new tasks.
quit_flag = False

# MongoDB handles: `papers` holds the imported records, `<collection>_data`
# holds the data crawled from Semantic Scholar.
client = pymongo.MongoClient(db_url)
db = client[db_name]
papers, papers_data = db[db_collection], db[f'{db_collection}_data']
def read_file(filename):
    """Parse a JSON Lines file: one JSON document per non-blank line.

    :param filename: path of the file to read.
    :returns: list of the decoded JSON values, in file order.
    :raises json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Decode as UTF-8 explicitly; the default text encoding depends on the
    # platform locale. Blank lines (e.g. an extra empty line at the end of the
    # file) are skipped instead of making json.loads("") fail.
    with open(filename, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]
def _insert_batch(batch):
    """Insert *batch* into ``papers`` and return how many documents went in.

    With ``ordered=False`` MongoDB attempts every document even when some fail
    (typically duplicate ``corpusid`` keys from an earlier import). pymongo
    then raises ``BulkWriteError`` once the batch is done, and its
    ``nInserted`` detail still counts the documents that were written. Such
    per-document failures are absorbed here so the import can continue.
    """
    try:
        result = papers.insert_many(batch, ordered=False)
    except pymongo.errors.BulkWriteError as e:
        return e.details['nInserted']
    return len(result.inserted_ids)


def add_paper(file_path, batch_size=2000):
    """Import a JSON Lines paper file into the ``papers`` collection.

    Ensures a unique index on ``corpusid``, then inserts the records in
    batches of *batch_size*. Records rejected by the unique index are skipped,
    and every batch is still attempted, so re-running an import only adds the
    missing records. A progress line is printed after each batch, and a
    summary is printed at the end (also when an insert raises).

    :param file_path: path of the JSON Lines file (parsed by ``read_file``).
    :param batch_size: number of documents per ``insert_many`` call.
    :raises ValueError: if *batch_size* is less than 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))
    papers.create_index("corpusid", unique=True)
    data_list = read_file(file_path)
    total = len(data_list)
    inserted = 0
    try:
        for start in range(0, total, batch_size):
            # Each batch handles its own BulkWriteError. Duplicates in one batch
            # therefore no longer abort the remaining batches, and the inserted
            # counts add up across batches.
            inserted += _insert_batch(data_list[start:start + batch_size])
            print('-------process', inserted, '/', total)
    finally:
        # The third figure is the collection's total document count afterwards.
        print("总插入数据: {0}, 已插入数据: {1}, 已存在数据: {2}".format(
            total, inserted, papers.count_documents({})))
def crawl_data():
    """Fetch Semantic Scholar data for every record in ``papers`` not yet consumed.

    Steps:

    - Start NUM_THREADS daemon ``worker`` threads that read ``(url, corpusid)``
      tasks from a queue bounded to TASK_QUEUE_LEN entries.
    - Enqueue every record whose ``consumed`` flag is missing or False.
    - Wait until all queued tasks are done.
    - Stop the workers with one ``None`` sentinel each.

    Once ``quit_flag`` is set (Ctrl+C, see ``onSigInt``), no further tasks are
    enqueued, but tasks already in the queue still run to completion.
    """
    papers_data.create_index("corpusId", unique=True)

    # Start the worker pool on a bounded task queue.
    q = Queue(TASK_QUEUE_LEN)
    threads = []
    for _ in range(NUM_THREADS):
        t = Thread(target=worker, args=(q,))
        t.daemon = True
        t.start()
        print("starting worker: {}".format(t.native_id))
        threads.append(t)

    # Collect the pending tasks before feeding the queue. q.put() blocks while
    # the workers are busy with slow HTTP calls. Iterating a live cursor at
    # that pace can leave it idle long enough for the server to close it
    # (MongoDB's default idle-cursor timeout is 10 minutes), and the next batch
    # fetch then fails. Projecting to the two needed fields keeps this list small.
    unconsumed = {'$or': [{'consumed': {'$exists': False}}, {'consumed': False}]}
    projection = {'_id': 0, 'url': 1, 'corpusid': 1}
    pending = [(doc.get('url'), doc['corpusid'])
               for doc in papers.find(unconsumed, projection)]

    # The query already excludes consumed records, so they are not re-checked here.
    for url, corpus_id in pending:
        if quit_flag:
            break
        if not url:
            # worker() derives the paper id from the URL; nothing to crawl without one.
            print('skip {}: no url'.format(corpus_id))
            continue
        print('add {} to the task queue'.format(corpus_id))
        q.put((url, corpus_id))

    print("Waitting for the task queue to complete...")
    q.join()
    print("The task queue has been completed!")

    # Stop the workers: one sentinel per thread.
    for _ in threads:
        q.put(None)
    for t in threads:
        print("stoping worker: {}".format(t.native_id))
        t.join()
def mark_data_as_consumed(corpus_id):
    """Set ``consumed: True`` on the ``papers`` record whose ``corpusid`` is *corpus_id*.

    crawl_data only enqueues records whose ``consumed`` flag is missing or
    False, so marked records are not crawled again on later runs.

    :returns: the pymongo ``UpdateResult``, so callers can check
        ``matched_count`` if they want to.
    """
    return papers.update_one({'corpusid': corpus_id},
                             {'$set': {'consumed': True}})
def worker(q):
    """Process ``(url, corpus_id)`` tasks from *q* until a ``None`` sentinel arrives.

    For each task:

    - Take the Semantic Scholar paper id from the last path segment of the URL.
    - Fetch the paper with ``fetch_data``.
    - Upsert the result into ``papers_data``, keyed by ``corpusId``.
    - Mark the source record in ``papers`` as consumed.

    Errors are printed and the task is dropped. Every task, successful or
    not, is acknowledged with ``q.task_done()`` so that ``q.join()`` can return.
    """
    while True:
        item = q.get()
        if item is None:
            # Shutdown sentinel; no task_done() is needed for it.
            break
        try:
            # Parsing stays inside the try. An exception escaping here would
            # kill the thread without task_done(), and q.join() would then
            # block forever.
            url, corpus_id = item
            path = urlparse(url).path
            paper_id = path.rstrip('/').split('/')[-1]
            print('crawling {} data: {}'.format(corpus_id, path))
            data = fetch_data(paper_id)
            if data is not None:
                papers_data.update_one({'corpusId': corpus_id},
                                       {'$set': data}, upsert=True)
                mark_data_as_consumed(corpus_id)
                # upserted_id is None when an existing document was updated,
                # so report the corpus id instead.
                print(corpus_id, "inserted successfully")
        except Exception as error:
            # Best-effort crawl: report and move on to the next task.
            print("An exception occurred:", error)
        finally:
            q.task_done()
def get_paper(paper_id, timeout=30):
    """Fetch one paper from the Semantic Scholar Graph API with the QUERY_FIELDS1 fields.

    :param paper_id: id accepted by ``/graph/v1/paper/{paper_id}``.
    :param timeout: connect/read timeout in seconds passed to requests. It
        keeps a stalled connection from blocking a worker thread forever.
    :returns: the decoded JSON response.
    :raises requests.HTTPError: on a non-2xx response.
    :raises requests.Timeout: if the server does not respond in time.
    """
    rsp = requests.get(f'https://api.semanticscholar.org/graph/v1/paper/{paper_id}',
                       headers={'x-api-key': S2_API_KEY},
                       params={'fields': QUERY_FIELDS1},
                       timeout=timeout)
    rsp.raise_for_status()
    return rsp.json()
def get_citations(paper_id):
    """Return the papers citing *paper_id*, collected across all result pages.

    Each element is the ``citingPaper`` object of a citation edge, requested
    with the QUERY_FIELDS2 fields.
    """
    endpoint = f'https://api.semanticscholar.org/graph/v1/paper/{paper_id}/citations'
    edges = get_citation_edges(
        url=endpoint,
        headers={'x-api-key': S2_API_KEY},
        params={'fields': QUERY_FIELDS2},
    )
    return [edge['citingPaper'] for edge in edges]
def get_references(paper_id):
    """Return the papers referenced by *paper_id*, collected across all result pages.

    Each element is the ``citedPaper`` object of a reference edge, requested
    with the QUERY_FIELDS2 fields.
    """
    endpoint = f'https://api.semanticscholar.org/graph/v1/paper/{paper_id}/references'
    edges = get_citation_edges(
        url=endpoint,
        headers={'x-api-key': S2_API_KEY},
        params={'fields': QUERY_FIELDS2},
    )
    return [edge['citedPaper'] for edge in edges]
# NOTE: this endpoint (www.semanticscholar.org/api/1, not the public Graph API)
# sits behind a human-verification (CAPTCHA) check, so requests may be rejected.
def get_related_papers(paper_id, limit=10, timeout=30):
    """Return up to *limit* related papers for *paper_id* (QUERY_FIELDS3 fields).

    :param paper_id: Semantic Scholar paper id.
    :param limit: maximum number of related papers requested (default 10).
    :param timeout: connect/read timeout in seconds passed to requests.
    :returns: the ``papers`` list from the JSON response.
    :raises requests.HTTPError: on a non-2xx response.
    """
    rsp = requests.get(url=f'https://www.semanticscholar.org/api/1/paper/{paper_id}/related-papers',
                       headers={'x-api-key': S2_API_KEY},
                       params={'limit': limit,
                               'recommenderType': 'relatedPapers',
                               'fields': QUERY_FIELDS3},
                       timeout=timeout)
    rsp.raise_for_status()
    return rsp.json()['papers']
def get_recommended_papers(paper_id, timeout=30):
    """Return the Recommendations API's papers for *paper_id* (QUERY_FIELDS2 fields).

    :param paper_id: Semantic Scholar paper id.
    :param timeout: connect/read timeout in seconds passed to requests. It
        keeps a stalled connection from blocking a worker thread forever.
    :returns: the ``recommendedPapers`` list from the JSON response.
    :raises requests.HTTPError: on a non-2xx response.
    """
    rsp = requests.get(url=f'https://api.semanticscholar.org/recommendations/v1/papers/forpaper/{paper_id}',
                       headers={'x-api-key': S2_API_KEY},
                       params={'fields': QUERY_FIELDS2},
                       timeout=timeout)
    rsp.raise_for_status()
    return rsp.json()['recommendedPapers']
def get_citation_edges(*, page_size=1000, **req_kwargs):
    """Yield every element of a paginated Semantic Scholar endpoint.

    *req_kwargs* are forwarded to ``requests.get`` (``url``, ``headers``,
    ``params``, ...). The ``limit``/``offset`` query parameters are managed
    here. Pages are requested until one returns fewer than *page_size*
    elements. A 30-second timeout is used unless ``timeout`` is given.

    :param page_size: number of elements requested per page.
    :raises ValueError: if *page_size* is less than 1 (it would never terminate).
    :raises requests.HTTPError: on a non-2xx response.
    """
    if page_size < 1:
        raise ValueError('page_size must be >= 1, got {}'.format(page_size))
    # Copy the query parameters so the caller's dict is never mutated; this
    # also tolerates params=None.
    params = dict(req_kwargs.pop('params', None) or {})
    params['limit'] = page_size
    req_kwargs.setdefault('timeout', 30)
    offset = 0
    while True:
        params['offset'] = offset
        rsp = requests.get(params=params, **req_kwargs)
        rsp.raise_for_status()
        page = rsp.json()["data"]
        yield from page
        if len(page) < page_size:
            break  # a short page means there are no more results
        offset += page_size
def fetch_data(paper_id):
    """Fetch a paper together with its citations, references and recommendations.

    :param paper_id: Semantic Scholar paper id.
    :returns: the ``get_paper`` dict extended with ``citations``,
        ``references`` and ``recommendedPapers`` lists, or ``None`` if the
        paper response is not a JSON object.
    :raises requests.HTTPError: propagated from any of the API helpers.
    """
    print("fetching data:", paper_id)
    data = get_paper(paper_id)
    # Validate the shape before spending three more API calls on this paper.
    if not isinstance(data, dict):
        return None
    data['citations'] = get_citations(paper_id)
    data['references'] = get_references(paper_id)
    data['recommendedPapers'] = get_recommended_papers(paper_id)
    return data
def onSigInt(signo, frame):
    """SIGINT handler: ask crawl_data to stop enqueueing new tasks.

    Sets the module-level ``quit_flag``. crawl_data then stops queueing but
    still waits for the tasks already queued. The handler also restores
    Python's default SIGINT handler, so a second Ctrl+C raises
    ``KeyboardInterrupt`` and aborts immediately instead of being swallowed.
    """
    global quit_flag
    quit_flag = True
    print('Ctrl C: Waiting for the process to exit...')
    # Without this, every further Ctrl+C would only set the flag again, leaving
    # no way to interrupt a stalled crawl from the keyboard.
    signal.signal(signal.SIGINT, signal.default_int_handler)
    print('Press Ctrl C again to force quit.')
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Crawl data from URLs")
    parser.add_argument(
        "command", choices=["add_paper", "crawl_data"], help="Command to execute"
    )
    parser.add_argument("--path", help="Path to add to papers")
    args = parser.parse_args()
    if args.command == "add_paper":
        # Fail with a usage message rather than a TypeError from open(None).
        if not args.path:
            parser.error("--path is required for add_paper")
        add_paper(args.path)
    elif args.command == "crawl_data":
        # Graceful-exit handler for the main process. Only crawl_data checks
        # quit_flag, so the handler is installed just for it; add_paper keeps
        # the default Ctrl+C behaviour (KeyboardInterrupt).
        signal.signal(signal.SIGINT, onSigInt)
        crawl_data()
|