From 2b57f475560136da833b8b531d0162e54dd603fb Mon Sep 17 00:00:00 2001
From: "guorong.zheng" <360996299@qq.com>
Date: Tue, 3 Dec 2024 18:13:44 +0800
Subject: [PATCH] feat:get url download speed

---
 Pipfile        |  1 +
 Pipfile.lock   | 11 ++++++++-
 utils/speed.py | 62 +++++++++++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 72 insertions(+), 2 deletions(-)

diff --git a/Pipfile b/Pipfile
index 2ce52cf..fc355e2 100644
--- a/Pipfile
+++ b/Pipfile
@@ -39,6 +39,7 @@ fake-useragent = "*"
 gunicorn = "*"
 pillow = "*"
 yt-dlp = "*"
+m3u8 = "*"
 
 [requires]
 python_version = "3.13"
diff --git a/Pipfile.lock b/Pipfile.lock
index d4fbb30..0f3771c 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "607118c241d8006851d09df1f7eaa5299ede1023bfdf43911c258ce406732c3d"
+            "sha256": "c1675cff542fd8928f33bdc97dcb2a78bf11fda60b23d4d38317a5b56eb52e20"
         },
         "pipfile-spec": 6,
         "requires": {
@@ -441,6 +441,15 @@
             "markers": "python_version >= '3.7'",
             "version": "==3.1.4"
         },
+        "m3u8": {
+            "hashes": [
+                "sha256:566d0748739c552dad10f8c87150078de6a0ec25071fa48e6968e96fc6dcba5d",
+                "sha256:7ade990a1667d7a653bcaf9413b16c3eb5cd618982ff46aaff57fe6d9fa9c0fd"
+            ],
+            "index": "aliyun",
+            "markers": "python_version >= '3.7'",
+            "version": "==6.0.0"
+        },
         "markupsafe": {
             "hashes": [
                 "sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4",
diff --git a/utils/speed.py b/utils/speed.py
index 143b69d..cb9ac69 100644
--- a/utils/speed.py
+++ b/utils/speed.py
@@ -2,7 +2,9 @@ import asyncio
 import re
 import subprocess
 from time import time
+from urllib.parse import quote
 
+import m3u8
 import yt_dlp
 from aiohttp import ClientSession, TCPConnector
 
@@ -13,6 +15,62 @@ from utils.tools import is_ipv6, remove_cache_info, get_resolution_value, get_lo
 logger = get_logger(constants.log_path)
 
 
+async def get_speed_with_download(url, timeout=config.sort_timeout):
+    """
+    Get the speed of the url with a total timeout
+    """
+    start_time = time()
+    total_size = 0
+    total_time = 0
+    try:
+        async with ClientSession(
+                connector=TCPConnector(ssl=False), trust_env=True
+        ) as session:
+            async with session.get(url, timeout=timeout) as response:
+                async for chunk in response.content.iter_any():
+                    if chunk:
+                        total_size += len(chunk)
+    except Exception as e:
+        pass
+    finally:
+        end_time = time()
+        total_time += end_time - start_time
+    average_speed = (total_size / total_time if total_time > 0 else 0) / 1024
+    return average_speed
+
+
+async def get_speed_m3u8(url, timeout=config.sort_timeout):
+    """
+    Get the speed of the m3u8 url with a total timeout
+    """
+    start_time = time()
+    total_size = 0
+    total_time = 0
+    try:
+        url = quote(url, safe=':/?$&=@')
+        m3u8_obj = m3u8.load(url)
+        async with ClientSession(
+                connector=TCPConnector(ssl=False), trust_env=True
+        ) as session:
+            for segment in m3u8_obj.segments:
+                if time() - start_time > timeout:
+                    break
+                ts_url = segment.absolute_uri
+                async with session.get(ts_url, timeout=timeout) as response:
+                    file_size = 0
+                    async for chunk in response.content.iter_any():
+                        if chunk:
+                            file_size += len(chunk)
+                    end_time = time()
+                    download_time = end_time - start_time
+                    total_size += file_size
+                    total_time += download_time
+    except Exception as e:
+        pass
+    average_speed = (total_size / total_time if total_time > 0 else 0) / 1024
+    return average_speed
+
+
 def get_info_yt_dlp(url, timeout=config.sort_timeout):
     """
     Get the url info by yt_dlp
@@ -54,7 +112,7 @@ async def get_speed_requests(url, timeout=config.sort_timeout, proxy=None):
     Get the speed of the url by requests
     """
     async with ClientSession(
-            connector=TCPConnector(verify_ssl=False), trust_env=True
+            connector=TCPConnector(ssl=False), trust_env=True
     ) as session:
         start = time()
         end = None
@@ -171,6 +229,8 @@ async def get_speed(url, ipv6_proxy=None, callback=None):
             return speed_cache[cache_key][0]
         if ipv6_proxy and url_is_ipv6:
             speed = (0, None)
+        elif '.m3u8' in url:
+            speed = await get_speed_m3u8(url)
         else:
             speed = await get_speed_yt_dlp(url)
         if cache_key and cache_key not in speed_cache: