diff --git a/Pipfile b/Pipfile
index 149ef16..b33b3d2 100644
--- a/Pipfile
+++ b/Pipfile
@@ -19,6 +19,7 @@ aiohttp = ">=3.9.4"
 bs4 = ">=0.0.2"
 tqdm = ">=4.66.3"
 async-timeout = ">=4.0.3"
+aiofiles = "*"
 
 [requires]
 python_version = ">=3.11"
diff --git a/Pipfile.lock b/Pipfile.lock
index 2930a96..476a870 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "95be8fbb1ed76f8a0bde1e1a098159813b8b736dbbcbb907c9586a1f9b142ce8"
+            "sha256": "6a516f7931002ffbd69f157c8b0c5ddc3066399b89b95cc2f6db555fdf40ddc1"
         },
         "pipfile-spec": 6,
         "requires": {
@@ -16,6 +16,15 @@
         ]
     },
     "default": {
+        "aiofiles": {
+            "hashes": [
+                "sha256:19297512c647d4b27a2cf7c34caa7e405c0d60b5560618a29a9fe027b18b0107",
+                "sha256:84ec2218d8419404abcb9f0c02df3f34c6e0a68ed41072acfb1cef5cbc29051a"
+            ],
+            "index": "pypi",
+            "markers": "python_version >= '3.7'",
+            "version": "==23.2.1"
+        },
         "aiohttp": {
             "hashes": [
                 "sha256:0605cc2c0088fcaae79f01c913a38611ad09ba68ff482402d3410bf59039bfb8",
@@ -548,21 +557,21 @@
         },
         "requests": {
             "hashes": [
-                "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f",
-                "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"
-            ],
-            "index": "pypi",
-            "markers": "python_version >= '3.7'",
-            "version": "==2.31.0"
-        },
-        "selenium": {
-            "hashes": [
-                "sha256:5b4f49240d61e687a73f7968ae2517d403882aae3550eae2a229c745e619f1d9",
-                "sha256:d9dfd6d0b021d71d0a48b865fe7746490ba82b81e9c87b212360006629eb1853"
+                "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760",
+                "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"
             ],
             "index": "pypi",
             "markers": "python_version >= '3.8'",
-            "version": "==4.19.0"
+            "version": "==2.32.3"
+        },
+        "selenium": {
+            "hashes": [
+                "sha256:4770ffe5a5264e609de7dc914be6b89987512040d5a8efb2abb181330d097993",
+                "sha256:650dbfa5159895ff00ad16e5ddb6ceecb86b90c7ed2012b3f041f64e6e4904fe"
+            ],
+            "index": "pypi",
+            "markers": "python_version >= '3.8'",
+            "version": "==4.21.0"
         },
         "selenium-stealth": {
             "hashes": [
@@ -612,11 +621,11 @@
         },
         "trio": {
             "hashes": [
-                "sha256:9b41f5993ad2c0e5f62d0acca320ec657fdb6b2a2c22b8c7aed6caf154475c4e",
-                "sha256:e6458efe29cc543e557a91e614e2b51710eba2961669329ce9c862d50c6e8e81"
+                "sha256:9f5314f014ea3af489e77b001861c535005c3858d38ec46b6b071ebfa339d7fb",
+                "sha256:e42617ba091e7b2e50c899052e83a3c403101841de925187f61e7b7eaebdf3fb"
             ],
             "markers": "python_version >= '3.8'",
-            "version": "==0.25.0"
+            "version": "==0.25.1"
         },
         "trio-websocket": {
             "hashes": [
@@ -628,11 +637,11 @@
         },
         "typing-extensions": {
             "hashes": [
-                "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0",
-                "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"
+                "sha256:8cbcdc8606ebcb0d95453ad7dc5065e6237b6aa230a31e81d0f440c30fed5fd8",
+                "sha256:b349c66bea9016ac22978d800cfff206d5f9816951f12a7d0ec5578b0a819594"
             ],
             "markers": "python_version >= '3.8'",
-            "version": "==4.11.0"
+            "version": "==4.12.0"
         },
         "urllib3": {
             "extras": [
diff --git a/main.py b/main.py
index 072ec40..6839e31 100644
--- a/main.py
+++ b/main.py
@@ -23,8 +23,6 @@ import logging
 from logging.handlers import RotatingFileHandler
 import os
 from tqdm import tqdm
-import threading
-from queue import Queue
 
 handler = RotatingFileHandler("result_new.log", encoding="utf-8")
 logging.basicConfig(
@@ -49,21 +47,16 @@ class UpdateSource:
 
     def __init__(self):
         self.driver = self.setup_driver()
-        self.stop_event = threading.Event()
         self.tasks = []
         self.results = {}
-        self.channel_queue = Queue()
-        self.lock = asyncio.Lock()
+        self.channel_queue = asyncio.Queue()
+        self.pbar = None
+        self.total = 0
 
     async def process_channel(self):
-        while True:
-            cate, name, old_urls = self.channel_queue.get()
+        while not self.channel_queue.empty():
+            cate, name, old_urls = await self.channel_queue.get()
             channel_urls = []
-            # pbar.set_description(
-            #     f"Processing {name}, {total_channels - pbar.n} channels remaining"
-            # )
-            # if pbar.n == 0:
-            #     self.update_progress(f"正在处理频道: {name}", 0)
             format_name = formatChannelName(name)
             info_list = []
             if config.open_subscribe:
@@ -88,8 +81,7 @@
             if (
                 config.open_sort
                 and not github_actions
-                or github_actions == "true"
-                # or (pbar.n <= 200 and github_actions == "true")
+                or (self.pbar.n <= 200 and github_actions == "true")
             ):
                 sorted_data = await sortUrlsBySpeedAndResolution(info_list)
                 if sorted_data:
@@ -104,28 +96,23 @@
                     )
             if len(channel_urls) == 0:
                 channel_urls = filterUrlsByPatterns(old_urls)
-        except Exception as e:
-            print(e)
-        # finally:
-        #     pbar.update()
-        #     self.update_progress(
-        #         f"正在处理频道: {name}", int((pbar.n / total_channels) * 100)
-        #     )
-        await updateChannelUrlsTxt(self.lock, cate, name, channel_urls)
+        except:
+            pass
+        await updateChannelUrlsTxt(cate, name, channel_urls)
         self.channel_queue.task_done()
 
+    async def run_task(self, task, pbar):
+        result = await task
+        pbar.update()
+        self.update_progress(f"正在更新...", int((self.pbar.n / self.total) * 100))
+        return result
+
     async def visitPage(self, channel_items):
-        # channel_names = [
-        #     name
-        #     for _, channel_obj in channel_items.items()
-        #     for name in channel_obj.keys()
-        # ]
         task_dict = {
             "open_subscribe": getChannelsBySubscribeUrls,
             "open_multicast": getChannelsByFOFA,
             "open_online_search": useAccessibleUrl,
         }
-        tasks = []
         for config_name, task_func in task_dict.items():
             if getattr(config, config_name):
                 task = None
@@ -136,34 +123,38 @@
                         task_func(self.driver, self.update_progress)
                     )
                 else:
-                    task = asyncio.create_task(task_func)
-                tasks.append(task)
-        task_results = await asyncio.gather(*tasks)
+                    task = asyncio.create_task(task_func())
+                if task:
+                    self.tasks.append(task)
+        task_results = await asyncio.gather(*self.tasks)
+        self.tasks = []
         for i, config_name in enumerate(
             [name for name in task_dict if getattr(config, name)]
         ):
             self.results[config_name] = task_results[i]
-        # total_channels = len(channel_names)
-        # pbar = tqdm(total=total_channels)
         for cate, channel_obj in channel_items.items():
             channel_obj_keys = channel_obj.keys()
             for name in channel_obj_keys:
-                self.channel_queue.put((cate, name, channel_obj[name]))
-        # pbar.close()
+                await self.channel_queue.put((cate, name, channel_obj[name]))
 
     async def main(self):
         try:
-            task = asyncio.create_task(self.visitPage(getChannelItems()))
-            self.tasks.append(task)
-            await task
+            await self.visitPage(getChannelItems())
             for _ in range(10):
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-                channel_thread = threading.Thread(
-                    target=loop.run_until_complete, args=(self.process_channel(),)
-                )
-                channel_thread.start()
-            self.channel_queue.join()
+                channel_task = asyncio.create_task(self.process_channel())
+                self.tasks.append(channel_task)
+            self.total = self.channel_queue.qsize()
+            self.pbar = tqdm(total=self.channel_queue.qsize())
+            self.pbar.set_description(
+                f"Processing..., {self.channel_queue.qsize()} channels remaining"
+            )
+            self.update_progress(f"正在更新...", int((self.pbar.n / self.total) * 100))
+            tasks_with_progress = [
+                self.run_task(task, self.pbar) for task in self.tasks
+            ]
+            await asyncio.gather(*tasks_with_progress)
+            self.tasks = []
+            self.pbar.close()
             for handler in logging.root.handlers[:]:
                 handler.close()
                 logging.root.removeHandler(handler)
@@ -180,13 +171,8 @@
 
     def start(self, callback):
         self.update_progress = callback
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        thread = threading.Thread(target=loop.run_until_complete, args=(self.main(),))
-        thread.start()
+        asyncio.run(self.main())
 
     def stop(self):
-        self.stop_event.set()
         for task in self.tasks:
             task.cancel()
-        self.stop_event.clear()
diff --git a/utils.py b/utils.py
index 16ae34e..8013902 100644
--- a/utils.py
+++ b/utils.py
@@ -4,6 +4,7 @@ except ImportError:
     import config
 import aiohttp
 import asyncio
+import aiofiles
 import time
 import re
 import datetime
@@ -21,8 +22,6 @@ from tqdm import tqdm
 from selenium.webdriver.common.by import By
 from selenium.webdriver.support.ui import WebDriverWait
 from selenium.webdriver.support import expected_conditions as EC
-from queue import Queue
-import threading
 
 
 def formatChannelName(name):
@@ -109,44 +108,60 @@ async def getChannelsBySubscribeUrls(callback):
     pattern = r"^(.*?),(?!#genre#)(.*?)$"
     subscribe_urls_len = len(config.subscribe_urls)
     pbar = tqdm(total=subscribe_urls_len)
+    queue = asyncio.Queue()
     for base_url in config.subscribe_urls:
-        try:
-            pbar.set_description(
-                f"Processing subscribe {base_url}, {subscribe_urls_len - pbar.n} urls remaining"
-            )
-            if pbar.n == 0:
-                callback(f"正在获取订阅源", 0)
-            try:
-                response = requests.get(base_url, timeout=30)
-            except requests.exceptions.Timeout:
-                print(f"Timeout on {base_url}")
-                continue
-            content = response.text
-            if content:
-                lines = content.split("\n")
-                for line in lines:
-                    if re.match(pattern, line) is not None:
-                        key = re.match(pattern, line).group(1)
-                        resolution_match = re.search(r"_(\((.*?)\))", key)
-                        resolution = (
-                            resolution_match.group(2)
-                            if resolution_match is not None
-                            else None
-                        )
-                        key = formatChannelName(key)
-                        url = re.match(pattern, line).group(2)
-                        value = (url, None, resolution)
-                        if key in channels:
-                            if value not in channels[key]:
-                                channels[key].append(value)
-                        else:
-                            channels[key] = [value]
-        except Exception as e:
-            print(f"Error on {base_url}: {e}")
-            continue
-        finally:
-            pbar.update()
-            callback(f"正在获取订阅源", int((pbar.n / subscribe_urls_len) * 100))
+        await queue.put(base_url)
+
+    async def processSubscribeChannels():
+        while not queue.empty():
+            base_url = await queue.get()
+            if base_url:
+                try:
+                    pbar.set_description(
+                        f"Processing subscribe {base_url}, {subscribe_urls_len - pbar.n} urls remaining"
+                    )
+                    if pbar.n == 0:
+                        callback(f"正在获取订阅源", 0)
+                    try:
+                        response = requests.get(base_url, timeout=30)
+                    except requests.exceptions.Timeout:
+                        print(f"Timeout on {base_url}")
+                        continue
+                    content = response.text
+                    if content:
+                        lines = content.split("\n")
+                        for line in lines:
+                            if re.match(pattern, line) is not None:
+                                key = re.match(pattern, line).group(1)
+                                resolution_match = re.search(r"_(\((.*?)\))", key)
+                                resolution = (
+                                    resolution_match.group(2)
+                                    if resolution_match is not None
+                                    else None
+                                )
+                                key = formatChannelName(key)
+                                url = re.match(pattern, line).group(2)
+                                value = (url, None, resolution)
+                                if key in channels:
+                                    if value not in channels[key]:
+                                        channels[key].append(value)
+                                else:
+                                    channels[key] = [value]
+                except Exception as e:
+                    print(f"Error on {base_url}: {e}")
+                    continue
+                finally:
+                    pbar.update()
+                    callback(
+                        f"正在获取订阅源", int((pbar.n / subscribe_urls_len) * 100)
+                    )
+            queue.task_done()
+
+    tasks = []
+    for _ in range(10):
+        task = asyncio.create_task(processSubscribeChannels())
+        tasks.append(task)
+    await asyncio.gather(*tasks)
     print("Finished processing subscribe urls")
     pbar.close()
     return channels
@@ -201,20 +216,19 @@ def getChannelsInfoListByOnlineSearch(driver, pageUrl, name):
     return info_list
 
 
-async def updateChannelUrlsTxt(lock, cate, name, urls):
+async def updateChannelUrlsTxt(cate, name, urls):
     """
     Update the category and channel urls to the final file
     """
-    async with lock:
-        try:
-            with open("result_new.txt", "a", encoding="utf-8") as f:
-                f.write(cate + ",#genre#\n")
-                for url in urls:
-                    if url is not None:
-                        f.write(name + "," + url + "\n")
-                f.write("\n")
-        finally:
-            f.close
+    try:
+        async with aiofiles.open("result_new.txt", "a", encoding="utf-8") as f:
+            await f.write(cate + ",#genre#\n")
+            for url in urls:
+                if url is not None:
+                    await f.write(name + "," + url + "\n")
+            await f.write("\n")
+    finally:
+        f.close
 
 
 def updateFile(final_file, old_file):
@@ -461,11 +475,12 @@
     """
     baseUrl1 = "https://www.foodieguide.com/iptvsearch/"
     baseUrl2 = "http://tonkiang.us/"
-    speed1 = await getSpeed(baseUrl1, 30)
-    speed2 = await getSpeed(baseUrl2, 30)
-    if speed1 == float("inf") and speed2 == float("inf"):
+    task1 = asyncio.create_task(getSpeed(baseUrl1, 30))
+    task2 = asyncio.create_task(getSpeed(baseUrl2, 30))
+    task_results = await asyncio.gather(task1, task2)
+    if task_results[0] == float("inf") and task_results[1] == float("inf"):
         return None
-    if speed1 < speed2:
+    if task_results[0] < task_results[1]:
         return baseUrl1
     else:
         return baseUrl2
@@ -487,69 +502,6 @@ def getFOFAUrlsFromRegionList():
     return urls
 
 
-fofa_results = {}
-fofa_queue = Queue()
-
-
-async def processFOFAChannels(pbar, fofa_urls_len, driver, callback):
-    while True:
-        fofa_url = fofa_queue.get()
-        if fofa_url:
-            try:
-                pbar.set_description(
-                    f"Processing multicast {fofa_url}, {fofa_urls_len - pbar.n} urls remaining"
-                )
-                if pbar.n == 0:
-                    callback(f"正在获取组播源", 0)
-                driver.get(fofa_url)
-                await asyncio.sleep(10)
-                fofa_source = re.sub(
-                    r"<!--.*?-->", "", driver.page_source, flags=re.DOTALL
-                )
-                urls = set(re.findall(r"https?://[\w\.-]+:\d+", fofa_source))
-                channels = {}
-                for url in urls:
-                    try:
-                        response = requests.get(
-                            url + "/iptv/live/1000.json?key=txiptv", timeout=2
-                        )
-                        try:
-                            json_data = response.json()
-                            if json_data["code"] == 0:
-                                try:
-                                    for item in json_data["data"]:
-                                        if isinstance(item, dict):
-                                            item_name = formatChannelName(
-                                                item.get("name")
-                                            )
-                                            item_url = item.get("url").strip()
-                                            if item_name and item_url:
-                                                total_url = url + item_url
-                                                if item_name not in channels:
-                                                    channels[item_name] = [total_url]
-                                                else:
-                                                    channels[item_name].append(
-                                                        total_url
-                                                    )
-                                except Exception as e:
-                                    # print(f"Error on fofa: {e}")
-                                    continue
-                        except Exception as e:
-                            # print(f"{url}: {e}")
-                            continue
-                    except Exception as e:
-                        # print(f"{url}: {e}")
-                        continue
-                mergeObjects(fofa_results, channels)
-                fofa_queue.task_done()
-            except Exception as e:
-                print(e)
-                # continue
-            finally:
-                pbar.update()
-                callback(f"正在获取组播源", int((pbar.n / fofa_urls_len) * 100))
-
-
 async def getChannelsByFOFA(driver, callback):
     """
     Get the channel by FOFA
     """
@@ -557,18 +509,78 @@
     fofa_urls = getFOFAUrlsFromRegionList()
     fofa_urls_len = len(fofa_urls)
     pbar = tqdm(total=fofa_urls_len)
+    fofa_results = {}
+    fofa_queue = asyncio.Queue()
     for fofa_url in fofa_urls:
-        fofa_queue.put(fofa_url)
+        await fofa_queue.put(fofa_url)
+
+    async def processFOFAChannels(pbar, fofa_urls_len, driver, callback):
+        while not fofa_queue.empty():
+            fofa_url = await fofa_queue.get()
+            if fofa_url:
+                try:
+                    pbar.set_description(
+                        f"Processing multicast {fofa_url}, {fofa_urls_len - pbar.n} urls remaining"
+                    )
+                    if pbar.n == 0:
+                        callback(f"正在获取组播源", 0)
+                    driver.get(fofa_url)
+                    await asyncio.sleep(10)
+                    fofa_source = re.sub(
+                        r"<!--.*?-->", "", driver.page_source, flags=re.DOTALL
+                    )
+                    urls = set(re.findall(r"https?://[\w\.-]+:\d+", fofa_source))
+                    channels = {}
+                    for url in urls:
+                        try:
+                            response = requests.get(
+                                url + "/iptv/live/1000.json?key=txiptv", timeout=2
+                            )
+                            try:
+                                json_data = response.json()
+                                if json_data["code"] == 0:
+                                    try:
+                                        for item in json_data["data"]:
+                                            if isinstance(item, dict):
+                                                item_name = formatChannelName(
+                                                    item.get("name")
+                                                )
+                                                item_url = item.get("url").strip()
+                                                if item_name and item_url:
+                                                    total_url = url + item_url
+                                                    if item_name not in channels:
+                                                        channels[item_name] = [
+                                                            total_url
+                                                        ]
+                                                    else:
+                                                        channels[item_name].append(
+                                                            total_url
+                                                        )
+                                    except Exception as e:
+                                        # print(f"Error on fofa: {e}")
+                                        continue
+                            except Exception as e:
+                                # print(f"{url}: {e}")
+                                continue
+                        except Exception as e:
+                            # print(f"{url}: {e}")
+                            continue
+                    mergeObjects(fofa_results, channels)
+                except Exception as e:
+                    print(e)
+                    # continue
+                finally:
+                    pbar.update()
+                    callback(f"正在获取组播源", int((pbar.n / fofa_urls_len) * 100))
+            fofa_queue.task_done()
+
+    tasks = []
     for _ in range(10):
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        channel_thread = threading.Thread(
-            target=loop.run_until_complete,
-            args=(processFOFAChannels(pbar, fofa_urls_len, driver, callback),),
-            daemon=True,
+        task = asyncio.create_task(
+            processFOFAChannels(pbar, fofa_urls_len, driver, callback)
         )
-        channel_thread.start()
-    fofa_queue.join()
+        tasks.append(task)
+    await asyncio.gather(*tasks)
     pbar.close()
     return fofa_results
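
Note: every change above applies one pattern: replace threading.Thread workers reading a queue.Queue with a fixed pool of coroutines draining an asyncio.Queue inside a single event loop. A minimal runnable sketch of that pattern, for reference only (handle, worker, and num_workers are illustrative names, not identifiers from this codebase):

import asyncio
from tqdm import tqdm


async def handle(item):
    # Placeholder for the per-item work (network fetch, speed test, ...).
    await asyncio.sleep(0.1)


async def main(items, num_workers=10):
    queue = asyncio.Queue()
    for item in items:
        await queue.put(item)
    pbar = tqdm(total=queue.qsize())

    async def worker():
        # The empty() check is safe only because the queue is fully
        # populated before any worker starts; with live producers,
        # block on queue.get() and cancel the workers instead.
        while not queue.empty():
            item = await queue.get()
            try:
                await handle(item)
            finally:
                pbar.update()
                queue.task_done()

    await asyncio.gather(*(asyncio.create_task(worker()) for _ in range(num_workers)))
    pbar.close()


asyncio.run(main(range(50)))

Pre-filling the queue before starting the workers is what lets the while-not-empty guard terminate the pool without sentinels, which appears to be the same assumption the diff relies on.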