Explore the Code

config.json: add num_threads, task_queue_len

Ben 1 year ago
parent
commit
5e34cceace
2 files changed with 22 additions and 10 deletions
  1. config.json (+4, -1)
  2. spider.py (+18, -9)

+ 4 - 1
config.json

@@ -1,5 +1,8 @@
 {
     "db_url": "mongodb://localhost:27017/",
     "db_name": "paper_spider",
-    "db_collection": "papers"
+    "db_collection": "papers",
+    "s2_api_key": "b4YUQrO6w07Zyx9LN8V3p5Lg0WrrGDK520fWJfYd",
+    "num_threads": 10,
+    "task_queue_len": 10
 }
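
For reference, a minimal sketch of how spider.py can consume the two new keys. The standalone load and the .get() fallback defaults shown here are illustrative assumptions, not part of this commit; the actual wiring appears in the spider.py diff below.

    import json

    # Load the same config.json that this commit extends
    with open("config.json", encoding="utf-8") as f:
        config = json.load(f)

    # Fallback defaults here are assumptions for illustration only;
    # spider.py itself indexes the keys directly.
    NUM_THREADS = config.get("num_threads", 4)
    TASK_QUEUE_LEN = config.get("task_queue_len", 10)
    S2_API_KEY = config.get("s2_api_key", "")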

+ 18 - 9
spider.py

@@ -9,7 +9,6 @@ from threading import Thread
 from urllib.parse import urlparse
 
 # S2_API_KEY = os.getenv('S2_API_KEY')
-S2_API_KEY = 'b4YUQrO6w07Zyx9LN8V3p5Lg0WrrGDK520fWJfYd'
 QUERY_FIELDS1 = 'paperId,corpusId,title,authors,year,url,tldr,venue,externalIds,fieldsOfStudy,s2FieldsOfStudy,abstract,citationCount,referenceCount,publicationTypes,influentialCitationCount,publicationDate,journal'
 QUERY_FIELDS2 = 'paperId,corpusId,title,authors,year,url,venue,externalIds,fieldsOfStudy,s2FieldsOfStudy,abstract,citationCount,referenceCount,publicationTypes,influentialCitationCount,publicationDate,journal'
 QUERY_FIELDS3 = 'paperId,corpusId,title,authors'
@@ -21,6 +20,13 @@ db_url = config["db_url"]
 db_name = config["db_name"]
 db_collection = config["db_collection"]
 
+NUM_THREADS = config["num_threads"]
+TASK_QUEUE_LEN = config["task_queue_len"]
+S2_API_KEY = config["s2_api_key"]
+
+# Define the quit flag
+quit_flag = False
+
 # Connect to the database and create the papers collection
 client = pymongo.MongoClient(db_url)
 db = client[db_name]
@@ -61,10 +67,10 @@ def crawl_data():
     papers_data.create_index("corpusId", unique=True)
 
     # Create the task queue and worker threads
-    q = Queue()
-    num_threads = 4
+    q = Queue(TASK_QUEUE_LEN)
+    # num_threads = 4
     threads = []
-    for i in range(num_threads):
+    for i in range(NUM_THREADS):
         t = Thread(target=worker, args=(q,))
         t.daemon = True
         t.start()
@@ -73,6 +79,8 @@ def crawl_data():
 
     # Read URLs from the database and add them to the task queue
     for data in papers.find():
+        if quit_flag is True:
+            break
         url = data["url"]
         corpusid = data["corpusid"]
         if 'consumed' in data.keys() and data['consumed'] is True:
@@ -80,9 +88,8 @@ def crawl_data():
             continue
         # print(data['corpusid'])
         # print(data['url'])
-
+        print('add {} to the task queue'.format(corpusid))
         q.put((url, corpusid))
-        break
 
     #
     print("Waiting for the task queue to complete...")
@@ -90,7 +97,7 @@ def crawl_data():
     print("The task queue has been completed!")
 
     # Stop the worker threads
-    for i in range(num_threads):
+    for i in range(NUM_THREADS):
         q.put(None)
     for t in threads:
         print("stopping worker: {}".format(t.native_id))
@@ -199,12 +206,14 @@ def fetch_data(paper_id):
 
 
 def onSigInt(signo, frame):
-    pass
+    global quit_flag
+    quit_flag = True
+    print('Ctrl C: Waiting for the process to exit...')
 
 
 if __name__ == "__main__":
     # Main process exit signal
-    # signal.signal(signal.SIGINT, onSigInt)
+    signal.signal(signal.SIGINT, onSigInt)
 
     parser = argparse.ArgumentParser(description="Crawl data from URLs")
     parser.add_argument(
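
Taken together, the spider.py changes implement a bounded producer/consumer pipeline with a cooperative shutdown flag: a Queue capped at task_queue_len provides backpressure, num_threads daemon workers drain it, one None sentinel per worker stops the pool, and SIGINT only sets quit_flag so in-flight tasks can finish. Below is a standalone sketch of that pattern; handle_task and the range() producer are placeholders, not code from spider.py.

    import signal
    from queue import Queue
    from threading import Thread

    quit_flag = False  # set by the SIGINT handler, checked by the producer loop

    def on_sigint(signo, frame):
        global quit_flag
        quit_flag = True
        print("Ctrl+C: waiting for queued tasks to finish...")

    def handle_task(item):
        # Placeholder for the real fetch/store logic in spider.py
        print("processing", item)

    def worker(q):
        while True:
            item = q.get()
            if item is None:      # sentinel: stop this worker
                q.task_done()
                break
            handle_task(item)
            q.task_done()

    def main(num_threads=10, task_queue_len=10):
        signal.signal(signal.SIGINT, on_sigint)
        q = Queue(task_queue_len)  # bounded queue: put() blocks when full
        threads = [Thread(target=worker, args=(q,), daemon=True) for _ in range(num_threads)]
        for t in threads:
            t.start()

        for item in range(100):    # stand-in for iterating papers.find()
            if quit_flag:
                break
            q.put(item)

        q.join()                   # wait until every queued task is done
        for _ in threads:
            q.put(None)            # one sentinel per worker
        for t in threads:
            t.join()

    if __name__ == "__main__":
        main()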