diff --git a/.cursor/rules/weclone-rules.mdc b/.cursor/rules/weclone-rules.mdc new file mode 100644 index 0000000000000000000000000000000000000000..80fccee2202e398027bbfa78ae407a79c53339ef --- /dev/null +++ b/.cursor/rules/weclone-rules.mdc @@ -0,0 +1,23 @@ +--- +description: +globs: +alwaysApply: true +--- +--- +description: +globs: +alwaysApply: true +--- + +# Your rule content +- You can @ files here +- The project uses uv as the package manager and pyproject.toml as the project configuration file. +- Unless I ask you to, code comments don't need to be excessive. +- Prefer using the encapsulated logger `from weclone.utils.log import logger` for printing. +- When retrieving values from a parameter dictionary read from a configuration file, the `get` method should be preferred whenever possible. + + + + + + diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..62c3c8c54d42d509480ba3ed07b66f72d1500c2f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +torchvision.whl filter=lfs diff=lfs merge=lfs -text +weclone-audio/src/sample.wav filter=lfs diff=lfs merge=lfs -text diff --git a/.github/issue-labeler.yml b/.github/issue-labeler.yml new file mode 100644 index 0000000000000000000000000000000000000000..22ece25a2589a5146888656741058c38557aeecf --- /dev/null +++ b/.github/issue-labeler.yml @@ -0,0 +1,47 @@ +# 添加 Discussion 标签 +Discussion: + - '(讨论|交流|分享|意见|建议|思考|探讨|交换意见|brainstorm|discussion)' + +# 添加 bug 标签 +bug: + - '(bug|错误|问题|失败|崩溃|异常|报错|不工作|无法运行|broken|crash|error|exception|fails)' + +# 添加 chatbot 标签 +chatbot: + - '(聊天机器人|chatbot|chat bot|对话机器人|聊天助手|AI助手|机器人对话|bot|assistant)' + +# 添加 documentation 标签 +documentation: + - '(文档|说明|使用指南|指导|手册|教程|文档更新|documentation|docs|guide|tutorial|readme)' + +# 添加 duplicate 标签 +duplicate: + - '(重复|已有|duplicate|已经存在|已提交过|重复问题|重复报告|dup)' + +# 添加 feature 标签 +feature: + - '(功能|特性|新增|增加|添加|实现|feature|enhancement|新功能|功能请求|feature request)' + +# 添加 good first issue 标签 +good first issue: + - '(入门|简单|容易|新手|初学者|开始|first|beginner|starter|easy|简单任务|good first issue)' + +# 添加 help wanted 标签 +help wanted: + - '(需要帮助|寻求帮助|请求协助|help|求助|协助|帮忙|help wanted|need help|assistance)' + +# 添加 invalid 标签 +invalid: + - '(无效|不适用|不相关|无关|错误提交|invalid|not relevant|irrelevant|not applicable)' + +# 添加 Mac 标签 +Mac: + - '(Mac|MacOS|macOS|OSX|Mac系统|苹果系统|苹果电脑|MacBook)' + +# 添加 question 标签 +question: + - '(问题|疑问|如何|怎么|请问|是否|能否|可以吗|question|how to|what is|why)' + +# 添加 Windows 标签 +Windows: + - '(Windows|微软|Win10|Win11|Windows系统|微软系统|win)' diff --git a/.github/workflows/issue-labeler.yml b/.github/workflows/issue-labeler.yml new file mode 100644 index 0000000000000000000000000000000000000000..3afb59e903930d69355803485d9de4931e9582a4 --- /dev/null +++ b/.github/workflows/issue-labeler.yml @@ -0,0 +1,30 @@ +name: add labels to Issues + +on: + issues: + types: [opened, edited] + + +jobs: + label_issues: + runs-on: ubuntu-latest + permissions: + issues: write + contents: read + steps: + - name: checkout + uses: actions/checkout@v3 + + - name: get_last_run_time + id: last_run + run: | + # 获取当前日期减去 1 天作为默认值(处理最近一天的 issues) + echo "date=$(date -d '1 day ago' -u +"%Y-%m-%dT%H:%M:%SZ")" >> $GITHUB_OUTPUT + + - name: RegEx Issue Labeler + uses: github/issue-labeler@v3.4 + with: + repo-token: "${{ secrets.GITHUB_TOKEN }}" + configuration-path: .github/issue-labeler.yml + 
enable-versioned-regex: 0 + not-before: ${{ steps.last_run.outputs.date }} diff --git a/.github/workflows/update_space.yml b/.github/workflows/update_space.yml new file mode 100644 index 0000000000000000000000000000000000000000..76b1e8dd220c2314a6acf96ef7fdd1cd364c144d --- /dev/null +++ b/.github/workflows/update_space.yml @@ -0,0 +1,28 @@ +name: Run Python script + +on: + push: + branches: + - n + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Install Gradio + run: python -m pip install gradio + + - name: Log in to Hugging Face + run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")' + + - name: Deploy to Spaces + run: gradio deploy diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..45efd306a875b0c84e4f3553187832a4055c92fb --- /dev/null +++ b/.gitignore @@ -0,0 +1,165 @@ +wandb/ +weclone_archive-my/ +**/pycache/ +events.out.tfevents.* +归档/ +*.pt +*.npz +*nohup.out +*log.txt +*cookie.bin +*.gradio/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + + +*.zip +LLaMA-Factory +chatglm3-6b +cache +archive +model_output* +data/test +.vscode +*-my*.* +*.csv +*test.* +*users.json +Spark-TTS-0.5B/ +uv.lock +output* +*.out + +Qwen*/ +settings.jsonc +settings.json +dataset/blocked_words.json +dataset/wechat/* diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..ada1a8176b23b97e70d89300c3f06c3c471bec2c --- /dev/null +++ b/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. 
This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. 
+ + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. 
+ + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. 
+ + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. 
+ + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. 
+ + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. 
diff --git a/README.md b/README.md index b19cfa0d519d7a0fe6a031efb932e89955d9cd6c..8ff13c3a4e233131a03457fa456a4c0b1f9b2db0 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,277 @@ ---- -title: Cren -emoji: 🏢 -colorFrom: yellow -colorTo: blue -sdk: gradio -sdk_version: 5.32.0 -app_file: app.py -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +--- +title: cren +app_file: spaces_app.py +sdk: gradio +sdk_version: 5.21.0 +--- +![download](https://github.com/user-attachments/assets/5842e84e-004f-4afd-9373-af64e9575b78) +

🚀 One-stop solution for creating your digital avatar from chat history 💡

+

🚀从聊天记录创造数字分身的一站式解决方案💡

+ + +
+ +[![GitHub stars](https://img.shields.io/github/stars/xming521/WeClone?style=for-the-badge&logo=github&label=Stars&logoColor=white&color=ffda65)](https://github.com/xming521/WeClone/stargazers) +[![GitHub release](https://img.shields.io/github/v/release/xming521/WeClone?style=for-the-badge&logo=github&label=Release&logoColor=white&color=06d094)](https://github.com/xming521/WeClone/releases) + + WeClone① + +[![Twitter](https://img.shields.io/badge/Twitter-@weclone567-000000?style=for-the-badge&logo=x&logoColor=white)](https://x.com/weclone567) +[![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white)](https://t.me/+JEdak4m0XEQ3NGNl) + +Featured|HelloGitHub +xming521%2FWeClone | Trendshift +Ask DeepWiki +
+ +

+ 项目主页 | + 项目文档 | + Windows部署指南 | + Linux部署指南【保姆级】 +

+ +> [!IMPORTANT] +>

WhatsApp and Telegram chat log integration for digital avatar creation is coming!

+ +## ✨核心功能 +- 💫 涵盖打造数字分身的全链路方案,包括聊天数据导出、预处理、模型训练、部署 +- 💬 使用微信聊天记录微调LLM,让大模型有"那味儿" +- 🔗 绑定到微信、QQ、Telegram、企微、飞书机器人,实现自己的数字分身 +- 🛡️ 隐私信息过滤,本地化微调部署,数据安全可控 + +## 📋特性与说明 + +> [!IMPORTANT] +> - WeClone仍在快速迭代期,当前效果不代表最终效果。 +> - 微调LLM效果很大程度取决于模型大小、聊天数据的数量和质量,理论上模型越大,数据越多,效果越好。 +> - Windows环境未进行严格测试,可以使用WSL作为运行环境。详细教程可点击[Windows部署指南](https://blog.051088.xyz/2025/05/14/WeClone-%E7%94%A8%E5%BE%AE%E4%BF%A1%E8%81%8A%E5%A4%A9%E8%AE%B0%E5%BD%95%E6%89%93%E9%80%A0%E8%87%AA%E5%B7%B1%E7%9A%84AI%E6%95%B0%E5%AD%97%E5%88%86%E8%BA%AB/)查看。 + +### 硬件要求 + +项目默认使用Qwen2.5-7B-Instruct模型,LoRA方法对sft阶段微调,大约需要16GB显存。也可以使用[LLaMA Factory](https://github.com/hiyouga/LLaMA-Factory/blob/main/README_zh.md#%E6%A8%A1%E5%9E%8B)支持的其他模型和方法。 + +需要显存的估算值: +| 方法 | 精度 | 7B | 14B | 30B | 70B | `x`B | +| ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- | +| Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB | +| Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB | +| Freeze/LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB | +| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB | +| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB | +| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB | + + +## 环境搭建 +1.cuda安装(已安装可跳过,**要求版本12.4及以上**):[LLaMA Factory](https://llamafactory.readthedocs.io/zh-cn/latest/getting_started/installation.html#cuda) + +2.建议使用 [uv](https://docs.astral.sh/uv/)安装依赖,这是一个非常快速的 Python 环境管理器。安装uv后,您可以使用以下命令创建一个新的Python环境并安装依赖项,注意这不包含音频克隆功能的依赖: +```bash +git clone https://github.com/xming521/WeClone.git +cd WeClone +uv venv .venv --python=3.10 +source .venv/bin/activate # windows下执行 .venv\Scripts\activate +uv pip install --group main -e . +``` +> [!TIP] +> 如果要使用最新的模型进行微调,需要手动安装最新版LLaMA Factory:`uv pip install --upgrade git+https://github.com/hiyouga/LLaMA-Factory.git`,同时其他依赖版本也可能需要修改,例如vllm pytorch transforms + +3.将配置文件模板复制一份并重命名为`settings.jsonc`,后续配置修改在此文件进行: +```bash +cp settings.template.jsonc settings.jsonc +``` +> [!NOTE] +> 训练以及推理相关配置统一在文件`settings.jsonc` + +4.使用以下命令测试CUDA环境是否正确配置并可被PyTorch识别,Mac不需要: +```bash +python -c "import torch; print('CUDA是否可用:', torch.cuda.is_available());" +``` + +5.(可选)安装FlashAttention,加速训练和推理:`uv pip install flash-attn --no-build-isolation` + +## 模型下载 +```bash +git lfs install +git clone https://www.modelscope.cn/Qwen/Qwen2.5-7B-Instruct.git +``` +下载有问题使用其他方式下载:[模型的下载](https://www.modelscope.cn/docs/models/download) + + +## 数据准备 + +请使用[PyWxDump](https://github.com/xaoyaoo/PyWxDump)提取微信聊天记录(不支持4.0版本微信)。可以先将手机的聊天记录迁移(备份)到电脑,数据量更多一些。下载软件并解密数据库后,点击聊天备份,导出类型为CSV,可以导出多个联系人(不建议使用群聊记录),然后将导出的位于`wxdump_tmp/export` 的 `csv` 文件夹放在`./dataset`目录即可,也就是不同人聊天记录的文件夹一起放在 `./dataset/csv`。 + +## 数据预处理 + +- 项目默认去除了数据中的手机号、身份证号、邮箱、网址。还在`settings.jsonc`中提供了一个禁用词词库`blocked_words`,可以自行添加需要过滤的词句(会默认去掉包括禁用词的整句)。 +> [!IMPORTANT] +> 🚨 请一定注意保护个人隐私,不要泄露个人信息! 
+ +- 执行以下命令对数据进行处理,可以根据自己的聊天风格修改settings.jsonc的`make_dataset_args`。 +```bash +weclone-cli make-dataset +``` +- 目前仅支持时间窗口策略,根据`single_combine_time_window`将单人连续消息通过逗号连接合并为一句,根据`qa_match_time_window`匹配问答对。 +- 可以启用`clean_dataset`中的`enable_clean`选项,对数据进行清洗,以达到更好效果。* 当前系统支持使用 `llm judge` 对聊天记录进行打分,提供 **vllm 离线推理** 和 **API 在线推理** 两种方式。可通过将 `settings.jsonc` 文件中的 `"online_llm_clear": false` 修改为 `true` 来启用 API 在线推理模式,并配置相应的 `base_url`、`llm_api_key`、`model_name` 等参数。所有兼容 OpenAI 接口的模型均可接入。 +- 在获得 `llm 打分分数分布情况` 后,可通过设置 `accept_score` 参数筛选可接受的分数区间,同时可适当降低 `train_sft_args` 中的 `lora_dropout` 参数,以提升模型的拟合效果。 + +## 配置参数并微调模型 + +- (可选)修改 `settings.jsonc` 的 `model_name_or_path` 和 `template` 选择本地下载好的其他模型。 +- 修改`per_device_train_batch_size`以及`gradient_accumulation_steps`来调整显存占用。 +- 可以根据自己数据集的数量和质量修改`train_sft_args`的`num_train_epochs`、`lora_rank`、`lora_dropout`等参数。 + +### 单卡训练 +```bash +weclone-cli train-sft +``` +多卡环境单卡训练,需要先执行 `export CUDA_VISIBLE_DEVICES=0` + +### 多卡训练 +取消`settings.jsonc`中`deepspeed`行代码注释,使用以下命令多卡训练: +```bash +uv pip install deepspeed +deepspeed --num_gpus=使用显卡数量 weclone/train/train_sft.py +``` + +### 使用浏览器demo简单推理 +可以在这一步测试出合适的temperature、top_p值,修改settings.jsonc的`infer_args`后,供后续推理时使用。 +```bash +weclone-cli webchat-demo +``` + +### 使用接口进行推理 + +```bash +weclone-cli server +``` + +### 使用常见聊天问题测试 +不包含询问个人信息的问题,仅有日常聊天。测试结果在test_result-my.txt。 +```bash +weclone-cli server +weclone-cli test-model +``` + +## 🖼️ 微调效果 +使用Qwen2.5-14B-Instruct模型,大概3万条处理后的有效数据,loss降到了3.5左右的效果。 +
+截图 +
(四张微调效果截图) +
+
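在进入下一节的机器人接入之前,可以先用兼容 OpenAI 接口的客户端对 `weclone-cli server` 启动的服务做一次冒烟测试,确认接口可用、回复带有微调后的语气。下面是一个示意脚本(非项目自带):其中 `base_url`、端口与 `api_key` 均为假设值,请按自己的实际部署方式调整;模型名与采样参数沿用本 README 机器人配置及 `settings.template.jsonc` 中 `infer_args` 的示例值。

```python
# 冒烟测试示意:先在另一个终端执行 weclone-cli server 启动接口服务
# base_url、端口与 api_key 仅为示例(README 的 docker 示例为 http://172.17.0.1:8005/v1),请按实际部署修改
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8005/v1", api_key="sk-test")
resp = client.chat.completions.create(
    model="gpt-3.5-turbo",  # 与下文 AstrBot / LangBot 配置中填写的模型名保持一致
    messages=[
        {"role": "system", "content": "请你扮演一名人类,不要说自己是人工智能"},
        {"role": "user", "content": "最近在忙什么呢?"},
    ],
    temperature=0.5,  # 取自 infer_args 的模板默认值,可按试聊结果调整
    top_p=0.65,
)
print(resp.choices[0].message.content)
```

若返回内容符合预期,再继续下文的 AstrBot 或 LangBot 接入。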
+ + +## 🤖 部署到聊天机器人 + +### AstrBot + +[AstrBot](https://github.com/AstrBotDevs/AstrBot) 是易上手的多平台 LLM 聊天机器人及开发框架 ✨ 平台支持 QQ、QQ频道、Telegram、微信、企微、飞书。 + +使用步骤: +1. 部署 AstrBot +2. 在 AstrBot 中部署消息平台 +3. 执行 `weclone-cli server` 启动api服务 +4. 在 AstrBot 中新增服务提供商,类型选择OpenAI,API Base URL 根据AstrBot部署方式填写(例如docker部署可能为http://172.17.0.1:8005/v1) ,模型填写gpt-3.5-turbo,API Key随意填写一个 +5. 微调后不支持工具调用,请先关掉默认的工具,消息平台发送指令: `/tool off all`,否则会没有微调后的效果。 +6. 根据微调时使用的default_system,在 AstrBot 中设置系统提示词。 +![5](https://github.com/user-attachments/assets/19de7072-076a-4cdf-8ae6-46b9b89f536a) +> [!IMPORTANT] +> 检查api_service的日志,尽量保证大模型服务请求的参数和微调时一致,tool插件能力都关掉。 +7. 调整采样参数,例如temperature、top_p、top_k等 +[配置自定义的模型参数](https://astrbot.app/config/model-config.html#%E9%85%8D%E7%BD%AE%E8%87%AA%E5%AE%9A%E4%B9%89%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%8F%82%E6%95%B0) + +### LangBot + +[LangBot](https://github.com/RockChinQ/LangBot) 是一个开源的接入全球多种即时通信平台的 LLM 机器人平台,适合各种场景使用。 + +1. [部署 LangBot](https://github.com/RockChinQ/LangBot#-%E5%BC%80%E5%A7%8B%E4%BD%BF%E7%94%A8) +2. 在 LangBot 中添加一个机器人 +4. 在模型页添加新模型,名称`gpt-3.5-turbo`,供应商选择 OpenAI,填写 请求 URL 为 WeClone 的地址,详细连接方式可以参考[文档](https://docs.langbot.app/zh/workshop/network-details.html),API Key 任意填写。 + +image + +6. 在流水线配置中选择刚才添加的模型,或修改提示词配置 + +image + +## 📌 路线图 +- [ ] 更丰富的上下文:包括上下文对话、聊天对象信息、时间等 + 思考 +- [ ] Memory 支持 +- [ ] 支持多模态 +- [ ] 数据增强 +- [ ] 支持GUI + +## 问题解决 +- 微调问题:[LLaMA-Factory| FAQs | 常见问题](https://github.com/hiyouga/LLaMA-Factory/issues/4614) 或者更方便的 [![更方便的Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/hiyouga/LLaMA-Factory) + +## ❤️ 贡献代码 + +欢迎任何 Issues/Pull Requests! + +你可以通过查看Issues或帮助审核 PR(拉取请求)来贡献。对于新功能的添加,请先通过 Issue 讨论。 +运行`uv pip install --group dev -e .`安装开发依赖。 +项目使用`pytest`测试(测试脚本待完善),`pyright`检查类型,`ruff`检查代码格式。 + + +## ⚠️ 免责声明 +> [!CAUTION] +> 请勿用于非法用途,否则后果自负。 +
+1. 使用目的 + +* 本项目仅供学习交流使用,**请勿用于非法用途**,**请勿用于非法用途**,**请勿用于非法用途**,否则后果自负。 +* 用户理解并同意,任何违反法律法规、侵犯他人合法权益的行为,均与本项目及其开发者无关,后果由用户自行承担。 + +2. 使用期限 + +* 您应该在下载保存使用本项目的24小时内,删除本项目的源代码和程序;超出此期限的任何使用行为,一概与本项目及其开发者无关。 + +3. 操作规范 + +* 本项目仅允许在授权情况下使用数据训练,严禁用于非法目的,否则自行承担所有相关责任;用户如因违反此规定而引发的任何法律责任,将由用户自行承担,与本项目及其开发者无关。 +* 严禁用于窃取他人隐私,严禁用于窃取他人隐私,严禁用于窃取他人隐私,否则自行承担所有相关责任。 + +4. 免责声明接受 + +* 下载、保存、进一步浏览源代码或者下载安装、编译使用本程序,表示你同意本警告,并承诺遵守它; + +5. 禁止用于非法测试或渗透 + +* 禁止利用本项目的相关技术从事非法测试或渗透,禁止利用本项目的相关代码或相关技术从事任何非法工作,如因此产生的一切不良后果与本项目及其开发者无关。 +* 任何因此产生的不良后果,包括但不限于数据泄露、系统瘫痪、侵犯隐私等,均与本项目及其开发者无关,责任由用户自行承担。 + +6. 免责声明修改 + +* 本免责声明可能根据项目运行情况和法律法规的变化进行修改和调整。用户应定期查阅本页面以获取最新版本的免责声明,使用本项目时应遵守最新版本的免责声明。 + +7. 其他 + +* 除本免责声明规定外,用户在使用本项目过程中应遵守相关的法律法规和道德规范。对于因用户违反相关规定而引发的任何纠纷或损失,本项目及其开发者不承担任何责任。 + +* 请用户慎重阅读并理解本免责声明的所有内容,确保在使用本项目时严格遵守相关规定。 + +
+请用户慎重阅读并理解本免责声明的所有内容,确保在使用本项目时严格遵守相关规定。 + +
+
+
+ +## ⭐ Star History +> [!TIP] +> 如果本项目对您有帮助,或者您关注本项目的未来发展,请给项目 Star,谢谢 + +
+ +[![Star History Chart](https://api.star-history.com/svg?repos=xming521/WeClone&type=Date)](https://www.star-history.com/#xming521/WeClone&Date) + +
+ + +
克隆我们,保留灵魂的芬芳
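附:训练数据的简单校验示意。在执行 `weclone-cli train-sft` 之前,可以确认 `weclone-cli make-dataset` 生成的数据文件与 `dataset/res_csv/sft/dataset_info.json` 中的列映射一致。下面的脚本只是一个假设性的检查示例:文件路径与字段名取自该配置文件,并假设数据为 JSON 数组格式,若你的配置不同请相应调整。

```python
import json

# 路径与字段名取自 dataset/res_csv/sft/dataset_info.json,属于假设示例
DATA_PATH = "dataset/res_csv/sft/sft-my-l.json"

with open(DATA_PATH, encoding="utf-8") as f:
    samples = json.load(f)  # 假设为 JSON 数组,每个元素是一条问答样本

print("样本数:", len(samples))
# 预期字段:instruction / output / system;启用 prompt_with_history 时还应包含 history
print("首条样本字段:", sorted(samples[0].keys()))
```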
diff --git a/dataset/res_csv/pt/dataset_info.json b/dataset/res_csv/pt/dataset_info.json new file mode 100644 index 0000000000000000000000000000000000000000..77024e166a2e8329cc97863183184a894c810093 --- /dev/null +++ b/dataset/res_csv/pt/dataset_info.json @@ -0,0 +1,6 @@ +{"wechat-pt":{ + "file_name": "./pt-my.json", + "columns": { + "prompt": "c" + } +}} \ No newline at end of file diff --git a/dataset/res_csv/sft/dataset_info.json b/dataset/res_csv/sft/dataset_info.json new file mode 100644 index 0000000000000000000000000000000000000000..420740f2185753a1d055d267ed09951ca2c68f2c --- /dev/null +++ b/dataset/res_csv/sft/dataset_info.json @@ -0,0 +1,19 @@ +{ + "wechat-sft": { + "file_name": "sft-my-l.json", + "columns": { + "prompt": "instruction", + "response": "output", + "system": "system" + } + }, + "wechat-sft-with-history": { + "file_name": "sft-my-l.json", + "columns": { + "prompt": "instruction", + "response": "output", + "system": "system", + "history": "history" + } + } +} \ No newline at end of file diff --git a/dataset/test_data-privacy.json b/dataset/test_data-privacy.json new file mode 100644 index 0000000000000000000000000000000000000000..072796eed9bb875b556fafb760169ab8e17a7924 --- /dev/null +++ b/dataset/test_data-privacy.json @@ -0,0 +1,224 @@ +{ + "questions": [ + [ + "你多大了?" + ], + [ + "你有什么爱好吗?" + ], + [ + "你的理想是什么?", + "你觉得你离你的理想还有多远?" + ], + [ + "你最近在忙什么?", + "工作/学习顺利吗?", + "有什么有趣的事情发生吗?" + ], + [ + "你喜欢看什么类型的电影?", + "最近看过什么好看的电影吗?", + "你最喜欢的电影是什么?" + ], + [ + "你平时喜欢听什么音乐?", + "有推荐的歌手或乐队吗?", + "最近有喜欢的歌曲吗?" + ], + [ + "你喜欢旅游吗?", + "去过哪些地方?", + "最喜欢的旅游地是哪里?" + ], + [ + "你喜欢读书吗?", + "最近在读什么书?", + "最喜欢的书是哪本?" + ], + [ + "你平时喜欢运动吗?", + "喜欢做哪些运动?", + "有固定去锻炼吗?" + ], + [ + "周末一般都做些什么?", + "有没有什么特别的计划?", + "周末喜欢宅在家还是出去玩?" + ], + [ + "你喜欢宠物吗?", + "有养宠物吗?", + "最喜欢什么动物?" + ], + [ + "你喜欢吃什么类型的食物?", + "有推荐的餐厅吗?", + "最喜欢的菜是什么?" + ], + [ + "你喜欢什么样的天气?", + "最喜欢的季节是哪一个?", + "你觉得今天的天气怎么样?" + ], + [ + "你有看电视剧的习惯吗?", + "最近在追哪部剧?", + "最喜欢的电视剧是哪部?" + ], + [ + "你喜欢玩游戏吗?", + "最近在玩什么游戏?", + "有推荐的好玩的游戏吗?" + ], + [ + "你会做饭吗?", + "平时喜欢做哪些菜?", + "有没有特别拿手的菜?" + ], + [ + "你喜欢购物吗?", + "最近买了什么新东西?", + "有推荐的购物网站或店铺吗?" + ], + [ + "你平时怎么放松自己?", + "有特别的解压方式吗?", + "最喜欢的放松活动是什么?" + ], + [ + "你喜欢和朋友出去玩吗?", + "平时会和朋友去哪玩?", + "最近有没有和朋友聚会的计划?" + ], + [ + "你喜欢喝咖啡还是茶?", + "有没有特别喜欢的咖啡馆或茶馆?", + "最喜欢的饮品是什么?" + ], + [ + "你有兄弟姐妹吗?", + "和他们关系怎么样?", + "经常联系吗?" + ], + [ + "你喜欢读什么类型的杂志?", + "最近有看什么有趣的文章吗?", + "有订阅的杂志吗?" + ], + [ + "你喜欢看体育比赛吗?", + "最喜欢的运动项目是什么?", + "有没有特别支持的球队或运动员?" + ], + [ + "你会说其他语言吗?", + "最想学的语言是什么?", + "学习语言有什么技巧吗?" + ], + [ + "你对科技产品感兴趣吗?", + "最近有没有关注什么新科技?", + "最喜欢的电子产品是什么?" + ], + [ + "你喜欢喝什么样的饮料?", + "有没有自己调饮料的习惯?", + "最喜欢的饮品品牌是什么?" + ], + [ + "你平时用社交媒体吗?", + "常用哪些平台?", + "在社交媒体上做什么?" + ], + [ + "你对艺术感兴趣吗?", + "最喜欢的艺术家是谁?", + "有去过哪些艺术展览?" + ], + [ + "你喜欢DIY吗?", + "平时做些什么手工?", + "有没有完成的作品可以分享?" + ], + [ + "你喜欢种植植物吗?", + "有养什么植物?", + "最喜欢的植物是什么?" + ], + [ + "你喜欢拍照吗?", + "喜欢拍什么样的照片?", + "有没有用什么特别的摄影设备?" + ], + [ + "你喜欢听播客吗?", + "常听哪些主题的播客?", + "有没有推荐的播客?" + ], + [ + "你对历史感兴趣吗?", + "最喜欢哪个历史时期?", + "有没有特别喜欢的历史人物?" + ], + [ + "你喜欢画画吗?", + "平时画什么类型的画?", + "有参加过画展吗?" + ], + [ + "你喜欢写作吗?", + "平时写什么类型的文章?", + "有没有发表过作品?" + ], + [ + "你喜欢钓鱼吗?", + "平时去哪里钓鱼?", + "有没有钓到过什么大鱼?" + ], + [ + "你喜欢露营吗?", + "平时会去哪里露营?", + "有没有什么难忘的露营经历?" + ], + [ + "你喜欢摄影吗?", + "最喜欢拍什么题材?", + "有没有特别喜欢的摄影师?" + ], + [ + "你喜欢喝酒吗?", + "喜欢什么类型的酒?", + "有没有推荐的酒吧或品牌?" + ], + [ + "你喜欢滑雪吗?", + "平时去哪里滑雪?", + "有没有什么滑雪技巧分享?" + ], + [ + "你喜欢海边还是山里?", + "最喜欢去哪个地方度假?", + "有没有什么特别推荐的景点?" + ], + [ + "你喜欢参加音乐节吗?", + "参加过哪些音乐节?", + "最喜欢的音乐节是哪一个?" 
+ ], + [ + "你喜欢跑步吗?", + "平时跑多长距离?", + "有没有参加过马拉松?" + ], + [ + "你喜欢参加聚会吗?", + "平时和朋友聚会做什么?", + "有没有什么有趣的聚会游戏?" + ], + [ + "你喜欢收集东西吗?", + "收集什么类型的物品?", + "有没有什么特别的收藏?" + ] + ] +} \ No newline at end of file diff --git a/dataset/test_data.json b/dataset/test_data.json new file mode 100644 index 0000000000000000000000000000000000000000..a6792e9d11626298d1a6616c0ce2f7362453bc82 --- /dev/null +++ b/dataset/test_data.json @@ -0,0 +1,157 @@ +{ + "questions": [ + [ + "吃了吗?", + "吃的什么啊", + "好吃吗", + "多少钱啊", + "可以请我吃吗" + ], + [ + "干嘛呢?", + "等会准备干什么去" + ], + [ + "在忙什么呢?", + "今天有什么特别的安排吗?", + "感觉怎么样?" + ], + [ + "最近有什么新鲜事发生吗?", + "有没有什么有趣的故事可以分享?" + ], + [ + "周末过得怎么样?", + "做了什么好玩的?" + ], + [ + "最近看了什么好看的电影或电视剧吗?", + "有什么推荐的吗?", + "大概讲了什么内容呀?" + ], + [ + "今天天气怎么样?", + "你那里呢?" + ], + [ + "最近工作/学习顺利吗?", + "有没有遇到什么挑战?" + ], + [ + "嗨,这会儿在忙啥呢?", + "今天有什么特别的安排不?", + "一切都还顺利吧?" + ], + [ + "你那边现在天气咋样啊?", + "是大晴天还是有点阴沉沉的?", + "冷不冷,或者热不热呀?" + ], + [ + "到饭点儿了没呀?", + "今天打算犒劳一下自己,吃点啥好吃的?", + "有没有啥特别想吃的,或者想去哪家馆子尝尝鲜?" + ], + [ + "最近网上有啥好玩儿的新闻或者梗吗?", + "刷到啥有意思的视频或者段子没?分享一下呗!" + ], + [ + "待会儿有啥打算呀?", + "今天剩下的时间准备怎么过呢?" + ], + [ + "今天有没有碰到啥让你眼前一亮的小事儿?", + "随便聊聊呗,有啥轻松点的话题不?" + ], + [ + "今天有啥新发现或者小感悟没?", + "感觉今天过得快不快?节奏怎么样?" + ], + [ + "你现在周围环境咋样,吵不吵?", + "今天出门溜达了没,外面人多不多呀?", + "瞅瞅窗外,有啥特别的景儿不?" + ], + [ + "吃饭了没啊?", + "吃的啥呀?合胃口不?" + ], + [ + "今天怎么样啊?累不累?", + "有啥事儿不?" + ], + [ + "最近身体还好吧?", + "没什么不舒服的地方吧?" + ], + [ + "今天忙不忙啊?", + "都干啥了呀?" + ], + [ + "家里都挺好的吧?", + "有啥需要帮忙的不?" + ], + [ + "今天出门了没?", + "外面冷不冷/热不热啊?多穿点/注意防暑。" + ], + [ + "最近有啥开心的事儿不?说来听听!", + "或者有啥烦心事儿,跟我说说?" + ], + [ + "晚上早点休息啊,别熬太晚。", + "睡得好不好啊最近?" + ], + [ + "缺啥东西不?跟我说。", + "钱够不够花呀?" + ], + [ + "今天看到啥有意思的了没?", + "或者有啥想跟我分享的?" + ], + [ + "周末有啥安排啊?", + "要不要一起吃个饭/出去转转?" + ], + [ + "最近常联系的那些朋友都还好不?", + "有空多聚聚。" + ], + [ + "工作/学习上还顺利吧?", + "别太给自己压力啊。" + ], + [ + "今天做了啥好吃的呀?", + "下次也给我尝尝呗!" 
+ ], + [ + "有啥新闻没有啊最近?", + "跟我讲讲。" + ], + [ + "那谁谁谁最近怎么样了?", + "好久没听到他/她消息了。" + ], + [ + "今天心情好不好呀?", + "看你气色不错/有点疲惫。" + ], + [ + "有啥想吃的没?下次给你做/带。", + "或者想去哪儿玩,我陪你。" + ], + [ + "最近有没有看啥电视剧/电影啊?", + "有啥好看的推荐给我呗。" + ], + [ + "没事儿就早点回家/休息。", + "注意安全啊。" + ] + ] +} \ No newline at end of file diff --git a/ds_config.json b/ds_config.json new file mode 100644 index 0000000000000000000000000000000000000000..6b59ddcc45b1eede559645c85d673bca03569b3c --- /dev/null +++ b/ds_config.json @@ -0,0 +1,28 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "contiguous_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..6fd3a3d8ab10981a76e262773febbea1555c6f21 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,125 @@ +[project] +name = "WeClone" +version = "0.2.21" +description = "从聊天记录创造数字分身的一站式解决方案" +authors = [{ name = "xming521" }] +readme = "README.md" +requires-python = ">=3.10,<3.11" + +dependencies = [ + "pandas", + "commentjson", + "click", + "pydantic==2.10.6", + "setuptools>=78.1.0", + "loguru>=0.7.3", + "torch>=2.6.0", + "transformers==4.49.0", + "tomli; python_version < '3.11'", + "langchain", +] + +[tool.weclone] +# 配置文件的版本号,当配置文件结构或重要默认值发生变化时,应增加此版本号 +config_version = "0.2.21" + +# 配置文件更新日志 +config_changelog = """ +[0.2.1] - 2025-04-29 - 初始配置版本。 +[0.2.2] - 2025-05-01 - 增加llm清洗数据配置,blocked_words迁移到settings.jsonc统一配置文件。 +[0.2.21] - 2025-05-01 - 增加在线llm清洗数据配置,兼容openai风格接口。 +""" + +[dependency-groups] +# xcodec = ["xcodec2==0.1.3"] +sparktts = [ + "einops>=0.8.1", + "einx>=0.3.0", + "numpy==1.26.4", + "omegaconf>=2.3.0", + "packaging>=24.2", + "safetensors>=0.5.2", + "soundfile>=0.12.1", + "soxr>=0.5.0.post1", + "torchaudio>=2.6.0", + "tqdm>=4.66.5", +] +main = [ + "llamafactory>=0.9.2", + "openai==1.76.0", + "vllm==0.8.2; platform_system == 'Linux'", +] +dev = ["pytest", "pytest-order", "pyright", "ruff"] + +[project.scripts] +weclone-cli = "weclone.cli:cli" + +[tool.uv] +conflicts = [ + # [{ group = "wx" }, { group = "xcodec" }], +] + +[tool.uv.sources] +torch = [ + { index = "pytorch-cu124", marker = "platform_system == 'Windows'" }, + { index = "pytorch-cu124", marker = "platform_system == 'Linux'" }, +] +torchaudio = [ + { index = "pytorch-cu124", marker = "platform_system == 'Windows'" }, + { index = "pytorch-cu124", marker = "platform_system == 'Linux'" }, +] +torchvision = [ + { index = "pytorch-cu124", marker = "platform_system == 'Windows'" }, + { index = "pytorch-cu124", marker = "platform_system == 'Linux'" }, +] + + +[[tool.uv.index]] +url = "https://pypi.tuna.tsinghua.edu.cn/simple/" +default = true + +[[tool.uv.index]] +name = "pytorch-cu124" +url = "https://download.pytorch.org/whl/cu124" +explicit = true + +[tool.setuptools.packages.find] +where = ["."] # 表示在项目根目录开始查找 +include = ["weclone*"] # 只包含名为 weclone 的目录及其子包 +exclude = ["*tests*", "*archive*"] # 可以选择性排除其他模式,比如测试目录 + + +[tool.pyright] +typeCheckingMode = "basic" +include = ["weclone/data"] 
+exclude = ["**/archive", "**/tests"] +ignore = ["**/archive"] + +reportMissingImports = "error" +reportMissingTypeStubs = false + +pythonVersion = "3.10" +pythonPlatform = "Linux" + +[tool.ruff] +exclude = [ + "**/archive", + "**/tests", + "weclone-audio/src/server未完工", + "weclone-audio/src/Spark-TTS", +] +line-length = 120 + +lint.ignore = ["F403", "F405", "E501", "E402"] +lint.select = [ + "F", # Pyflakes + "W", # pycodestyle warnings + "E", # pycodestyle errors + "ASYNC", # flake8-async + "C4", # flake8-comprehensions + "Q", # flake8-quotes +] +target-version = "py310" + +[tool.pytest.ini_options] +addopts = "-x" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1746cc20a1f51aa5f773ca232e6237c49ebd5d5c Binary files /dev/null and b/requirements.txt differ diff --git a/settings.template.jsonc b/settings.template.jsonc new file mode 100644 index 0000000000000000000000000000000000000000..93ac4c4fc3e5a3f1a705bfa7ba89608cca00f2eb --- /dev/null +++ b/settings.template.jsonc @@ -0,0 +1,95 @@ +{ + "version": "0.2.21", + "common_args": { + "model_name_or_path": "./Qwen2.5-7B-Instruct", + "adapter_name_or_path": "./model_output", //同时做为train_sft_args的output_dir + "template": "qwen", + "default_system": "请你扮演一名人类,不要说自己是人工智能", + "finetuning_type": "lora", + "trust_remote_code": true + }, + "cli_args": { + "full_log": false + }, + "make_dataset_args": { + //数据处理配置 + "include_type": [ + "text", + // "image" + ], + "blocked_words": [ // 禁用词 + "例如 姓名", + "例如 密码", + "//....." + ], + "single_combine_strategy": "time_window", // 单人组成单句策略 + "qa_match_strategy": "time_window", // 组成qa策略 + "single_combine_time_window": 2, // 单人组成单句时间窗口(分钟), + "qa_match_time_window": 5, // 组成qa时间窗口(分钟), + "combine_msg_max_length": 256, // 组合后消息最大长度 配合cutoff_len 使用 + "prompt_with_history": false, // 是否在prompt中包含历史对话 + "clean_dataset": { + "enable_clean": true, + "clean_strategy": "llm", + "llm": { + "accept_score": 2, //可以接受的llm打分阈值,1分最差,5分最好,低于此分数的数据不会用于训练 + } + }, + "online_llm_clear": false, + "base_url": "https://xxx/v1", + "llm_api_key": "xxxxx", + "model_name": "xxx", //建议使用参数较大的模型,例如DeepSeek-V3 + "clean_batch_size": 10 + }, + "train_pt_args": { + //预训练微调配置 + "stage": "pt", + "dataset": "wechat-pt", + "dataset_dir": "./dataset/res_csv/pt", + "lora_target": "q_proj,v_proj", + "lora_rank": 2, + "lora_dropout": 0.1, + "output_dir": "model_output", + "overwrite_cache": true, + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 1, + "lr_scheduler_type": "cosine", + "logging_steps": 10, + "save_steps": 1000, + "learning_rate": 0.001, + "num_train_epochs": 30, + "plot_loss": true, + "fp16": true + }, + "train_sft_args": { + //微调配置 + "stage": "sft", + "dataset": "wechat-sft", + "dataset_dir": "./dataset/res_csv/sft", + "use_fast_tokenizer": true, + "lora_target": "q_proj,v_proj", + "lora_rank": 4, + "lora_dropout": 0.3, + "weight_decay": 0.1, + "overwrite_cache": true, + "per_device_train_batch_size": 8, + "gradient_accumulation_steps": 4, + "lr_scheduler_type": "cosine", + "cutoff_len": 256, + "logging_steps": 10, + "save_steps": 100, + "learning_rate": 1e-4, + "warmup_ratio": 0.1, + "num_train_epochs": 2, + "plot_loss": true, + "fp16": true, + "flash_attn": "fa2", + // "deepspeed": "ds_config.json" //多卡训练 + }, + "infer_args": { + "repetition_penalty": 1.2, + "temperature": 0.5, + "max_length": 50, + "top_p": 0.65 + } +} \ No newline at end of file diff --git a/spaces_app.py b/spaces_app.py new file mode 100644 index 
0000000000000000000000000000000000000000..db55e5fd216ecbb829f0d0f9ff3a4d83f40b4879 --- /dev/null +++ b/spaces_app.py @@ -0,0 +1,53 @@ +import gradio as gr +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch + +# 配置模型路径 - 使用您本地的模型目录 +MODEL_PATH = "./Qwen2.5-7B-Instruct" + +# 加载模型和分词器 +tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained( + MODEL_PATH, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True +) + +# 聊天函数 +def chat(message, history): + history = history or [] + chat_history = "" + for human, assistant in history: + chat_history += f"<|im_start|>user\n{human}<|im_end|>\n" + chat_history += f"<|im_start|>assistant\n{assistant}<|im_end|>\n" + + prompt = f"{chat_history}<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n" + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + outputs = model.generate( + **inputs, + max_new_tokens=512, + do_sample=True, + temperature=0.7, + top_p=0.9, + repetition_penalty=1.1, + eos_token_id=tokenizer.eos_token_id + ) + response = tokenizer.decode( + outputs[0][inputs.input_ids.shape[1]:], + skip_special_tokens=True + ) + return response + +# 创建界面 +demo = gr.ChatInterface( + chat, + title="WeClone AI 助手", + description="基于 Qwen2.5-7B 的聊天演示", + theme="soft", + examples=["你好", "介绍一下你自己", "你能做什么?"] +) + +# 导出为可部署对象 +app = demo \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/full_pipe.jsonc b/tests/full_pipe.jsonc new file mode 100644 index 0000000000000000000000000000000000000000..a002ad5c74f22ed1b031d7c0aa53b711d6561300 --- /dev/null +++ b/tests/full_pipe.jsonc @@ -0,0 +1,89 @@ +{ + "version": "0.2.2", + "common_args": { + "model_name_or_path": "./Qwen2.5-3B-Instruct", + "adapter_name_or_path": "./model_output", //同时做为train_sft_args的output_dir + "template": "qwen", + "default_system": "请你扮演一名人类,不要说自己是人工智能", + "finetuning_type": "lora", + "trust_remote_code": true + }, + "cli_args": { + "full_log": false + }, + "make_dataset_args": { + //数据处理配置 + "include_type": [ + "文本" + ], + "blocked_words": [ // 禁用词 + "例如 姓名", + "例如 密码", + "//....." 
+ ], + "single_combine_strategy": "time_window", // 单人组成单句策略 + "qa_match_strategy": "time_window", // 组成qa策略 + "single_combine_time_window": 2, // 单人组成单句时间窗口(分钟), + "qa_match_time_window": 5, // 组成qa时间窗口(分钟), + "combine_msg_max_length": 256, // 组合后消息最大长度 配合cutoff_len 使用 + "prompt_with_history": false, // 是否在prompt中包含历史对话 + "clean_dataset": { + "enable_clean": true, + "clean_strategy": "llm", + "llm": { + "accept_score": 2, //可以接受的llm打分阈值,1分最差,5分最好,低于此分数的数据不会用于训练 + } + } + }, + "train_pt_args": { + //预训练微调配置 + "stage": "pt", + "dataset": "wechat-pt", + "dataset_dir": "./dataset/res_csv/pt", + "lora_target": "q_proj,v_proj", + "lora_rank": 2, + "lora_dropout": 0.1, + "output_dir": "model_output", + "overwrite_cache": true, + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 1, + "lr_scheduler_type": "cosine", + "logging_steps": 10, + "save_steps": 1000, + "learning_rate": 0.001, + "num_train_epochs": 30, + "plot_loss": true, + "fp16": true + }, + "train_sft_args": { + //微调配置 + "stage": "sft", + "dataset": "wechat-sft", + "dataset_dir": "./dataset/res_csv/sft", + "use_fast_tokenizer": true, + "lora_target": "q_proj,v_proj", + "lora_rank": 4, + "lora_dropout": 0.3, + "weight_decay": 0.1, + "overwrite_cache": true, + "per_device_train_batch_size": 8, + "gradient_accumulation_steps": 4, + "lr_scheduler_type": "cosine", + "cutoff_len": 256, + "logging_steps": 5, + "save_steps": 10, + "learning_rate": 1e-4, + "warmup_ratio": 0.1, + "num_train_epochs": 1, + "plot_loss": true, + "fp16": true, + "flash_attn": "fa2", + // "deepspeed": "ds_config.json" //多卡训练 + }, + "infer_args": { + "repetition_penalty": 1.2, + "temperature": 0.5, + "max_length": 50, + "top_p": 0.65 + } +} \ No newline at end of file diff --git a/tests/test_full_pipe.py b/tests/test_full_pipe.py new file mode 100644 index 0000000000000000000000000000000000000000..e86a9e90647d05a64893cbb18d1665109b4c309e --- /dev/null +++ b/tests/test_full_pipe.py @@ -0,0 +1,154 @@ +import pytest +from unittest import mock +import sys +import os +import shutil +import functools +import subprocess +import time +from typing import Union, Optional, cast +from weclone.utils.log import logger + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +PROJECT_ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +server_process: Optional[subprocess.Popen] = None + +test_logger = logger.bind() +test_logger.remove() +test_logger.add( + sys.stderr, + format="{message}", + colorize=True, + level="INFO", +) + +def print_test_header(test_name: str): + line_length = 100 + test_logger.info("\n" + "─" * line_length) + title = f" Testing Phase: {test_name} " + padding_total = line_length - len(title) + padding_left = padding_total // 2 + padding_right = padding_total - padding_left + test_logger.info(" " * padding_left + title + " " * padding_right) + test_logger.info("─" * line_length) + +def setup_make_dataset_test_data(): + PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) + DATASET_CSV_DIR = os.path.join(PROJECT_ROOT, "dataset", "csv") + + TESTS_DIR = os.path.dirname(__file__) + TEST_DATA_PERSON_DIR = os.path.join(TESTS_DIR, "tests_data", "test_person") + + os.makedirs(DATASET_CSV_DIR, exist_ok=True) + + if os.path.exists(DATASET_CSV_DIR) and os.listdir(DATASET_CSV_DIR): + if all(f.startswith('.') or f.lower() == 'readme.md' for f in os.listdir(DATASET_CSV_DIR)): + for item_name in os.listdir(TEST_DATA_PERSON_DIR): + source_item_path = os.path.join(TEST_DATA_PERSON_DIR, 
item_name) + if os.path.isfile(source_item_path) and item_name.lower().endswith('.csv'): + destination_item_path = os.path.join(DATASET_CSV_DIR, item_name) + shutil.copy2(source_item_path, destination_item_path) + + +def run_cli_command(command: list[str], timeout: int | None = None, background: bool = False) -> Union[subprocess.CompletedProcess, subprocess.Popen]: + """Execute a CLI command and return the result. + + Args: + command: List of commands to execute. + timeout: Timeout in seconds. + background: Whether to run in the background. + + Returns: + If background=True, returns a Popen object; otherwise, returns a CompletedProcess object. + """ + env = os.environ.copy() + env["WECLONE_CONFIG_PATH"] = "tests/full_pipe.jsonc" # Set environment variable + + if background: + process = subprocess.Popen( + [sys.executable, "-m", "weclone.cli"] + command, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + text=True, + cwd=PROJECT_ROOT_DIR, + env=env + ) + time.sleep(2) + return process + else: + process = subprocess.run( + [sys.executable, "-m", "weclone.cli"] + command, + stderr=None, + stdout=None, + text=True, + cwd=PROJECT_ROOT_DIR, # Execute in the project root directory + timeout=timeout, + env=env # Pass the modified environment variables + ) + return process + +@pytest.mark.order(1) +def test_cli_make_dataset(): + """Test the make-dataset command.""" + print_test_header("make-dataset") + setup_make_dataset_test_data() + result = run_cli_command(["make-dataset"]) + assert result.returncode == 0, "make-dataset command execution failed" + +@pytest.mark.order(2) +def test_cli_train_sft(): + """Test the train-sft command.""" + print_test_header("train-sft") + try: + result = run_cli_command(["train-sft"]) + assert result.returncode == 0, "train-sft command failed or did not fail fast as expected" + except subprocess.TimeoutExpired: + test_logger.info("train-sft command terminated due to timeout, which is acceptable in testing, indicating the command has started execution.") + pass + except Exception as e: + pytest.fail(f"An unexpected error occurred during train-sft command execution: {e}") + +@pytest.mark.order(3) +def test_cli_webchat_demo(): + """Test the webchat-demo command.""" + print_test_header("webchat-demo") + + with mock.patch("weclone.eval.web_demo.main") as mock_main: + mock_main.return_value = None + try: + result = run_cli_command(["webchat-demo"], timeout=5) + assert result.returncode == 0, "webchat-demo command execution failed" + except subprocess.TimeoutExpired: + pass + +@pytest.mark.order(4) +def test_cli_server(): + """Test the server command. + + Start the server in the background, without blocking subsequent tests. + """ + print_test_header("server (background)") + global server_process + server_process = cast(subprocess.Popen, run_cli_command(["server"], background=True)) + assert server_process.poll() is None, "Server startup failed" + test_logger.info("服务器已在后台启动") + +@pytest.mark.order(5) +def test_cli_test_model(): + """Test the test-model command. + + Use the server for testing, and shut down the server after the test is complete. 
+ """ + print_test_header("test-model") + try: + result = run_cli_command(["test-model"]) + assert result.returncode == 0, "test-model command execution failed" + finally: + global server_process + if server_process is not None and server_process.poll() is None: + test_logger.info("测试完成,正在关闭服务器...") + server_process.terminate() + server_process.wait(timeout=5) + if server_process.poll() is None: + server_process.kill() # Force kill if the process hasn't terminated + test_logger.info("服务器已关闭") diff --git a/torchvision.whl b/torchvision.whl new file mode 100644 index 0000000000000000000000000000000000000000..b51723e2b3d27562384f38c0e7ca32aeecaa07f0 --- /dev/null +++ b/torchvision.whl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:307e52c2887c1d2b50cc3581cf5f4c169130b8352462e361e71eeda19e0dd263 +size 5660713 diff --git a/weclone-audio/README.md b/weclone-audio/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2edead418be7c10c201ed81c4c0fb145a25df97a --- /dev/null +++ b/weclone-audio/README.md @@ -0,0 +1,134 @@ +# WeClone-audio 模块 + +WeClone-audio 是一个使用微信语音消息克隆声音的模块,使用模型实现高质量语音合成。 +### 显存需求 +**Spark-TTS** 推荐 +- **0.5B 模型**: 约 4GB 显存 + +**Llasa** (已弃用) +- **3B 模型**: 约 16GB 显存 +- **1B 模型**: 约 9GB 显存 + + + + +## 1. 导出微信语音数据 + +### 1.1 准备工作 +- 使用 [PyWxDump](https://github.com/xaoyaoo/PyWxDump) 提取微信聊天记录 +- 下载软件并解密数据库 +- 点击聊天备份,导出类型选择"解密文件" + +### 1.2 环境配置 +语音导出仅支持Windows环境 +WeClone Audio使用uv作为包管理器。 +```bash +# 为 PyWxDump 创建 Python 环境和安装依赖 +# +uv venv .venv-wx --python=3.10 +.venv-wx\Scripts\activate +uv pip install pywxdump +``` + +### 1.3 导出语音文件 +```bash +python weclone-audio/src/get_sample_audio.py --db-path "导出数据库路径" --MsgSvrID "导出聊天记录的MsgSvrID字段" +``` + +## 2. 语音合成推理 +### Spark-TTS模型 + +**环境安装** +可不创建新环境,直接安装`sparktts`依赖组到WeClone共主环境 + +```bash +uv venv .venv-sparktts --python=3.10 +source .venv-sparktts/bin/activate +uv pip install --group sparktts -e . + +git clone https://github.com/SparkAudio/Spark-TTS.git weclone-audio/src/Spark-TTS +``` + + +**模型下载** + +通过python下载: +```python +from huggingface_hub import snapshot_download + +# 假设此 Python 代码在 weclone-audio 目录下运行 模型将下载到 weclone-audio/pretrained_models/Spark-TTS-0.5B +snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B") +``` + +或通过git下载: +```bash +# 假设当前在 weclone-audio 目录 +mkdir -p pretrained_models + +# Make sure you have git-lfs installed (https://git-lfs.com) +git lfs install +git clone https://huggingface.co/SparkAudio/Spark-TTS-0.5B pretrained_models/Spark-TTS-0.5B +``` +使用代码推理 +```python +import os +import SparkTTS +import soundfile as sf +import torch + +from SparkTTS import SparkTTS + +# 假设此 Python 代码在 weclone-audio 目录下运行 +# 模型路径相对于当前目录 +model_path = "pretrained_models/Spark-TTS-0.5B" +sample_audio = "sample.wav" +output_audio = "output.wav" + +model = SparkTTS(model_path, "cuda") + +with torch.no_grad(): + wav = model.inference( + text="晚上好啊,小可爱们,该睡觉了哦", + prompt_speech_path=sample_audio, # 使用相对路径 + prompt_text="对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。", + ) + sf.write(output_audio, wav, samplerate=16000) # 使用相对路径 +``` +### Llasa模型 (已弃用) +### 2.1 环境配置 +```bash +# 创建并配置推理环境 +## 可不创建新环境,与LLaMA-Factory环境共用 +uv venv .venv-xcodec --python=3.9 +source .venv-xcodec/bin/activate +uv pip install --group xcodec -e . 
+# 退出环境 +deactivate + +# 系统依赖安装(如果需要) +sudo apt install python3-dev +sudo apt install build-essential +``` + +### 2.2 使用代码推理 +如果遇到问题,请尝试将参考音频转换为WAV或MP3格式,将其裁剪至15秒以内,并缩短提示文本。 +```python +import os +import soundfile as sf +# 假设 text_to_speech.py 位于 src/ 或其他可导入的位置 +from text_to_speech import TextToSpeech + + +sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。" # 示例音频文本 +# 假设此 Python 代码在 weclone-audio 目录下运行 +# 示例音频路径相对于当前目录 +sample_audio_path = "sample.wav" +output_audio = "output.wav" + + +tts = TextToSpeech(sample_audio_path, sample_audio_text) +target_text = "晚上好啊" # 生成目标文本 +result = tts.infer(target_text) +sf.write(output_audio, result[1], result[0]) # 使用相对路径 +``` + diff --git a/weclone-audio/src/Llasa/infer.py b/weclone-audio/src/Llasa/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..81219eea68f52795bc0c6a916d93d59a21827b2e --- /dev/null +++ b/weclone-audio/src/Llasa/infer.py @@ -0,0 +1,12 @@ +import os +import soundfile as sf +from text_to_speech import TextToSpeech + + +sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。" # 示例音频文本 +sample_audio_path = os.path.join(os.path.dirname(__file__), "sample.wav") # 示例音频路径 +tts = TextToSpeech(sample_audio_path, sample_audio_text) +target_text = "晚上好啊" # 生成目标文本 +result = tts.infer(target_text) +sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), result[1], result[0]) # 保存生成音频 + diff --git a/weclone-audio/src/Llasa/text_to_speech.py b/weclone-audio/src/Llasa/text_to_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb468b77b6222c97c8fc0979c53d8f85ebac683 --- /dev/null +++ b/weclone-audio/src/Llasa/text_to_speech.py @@ -0,0 +1,131 @@ +import os +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +import soundfile as sf +from xcodec2.modeling_xcodec2 import XCodec2Model +import torchaudio + + +class TextToSpeech: + def __init__(self, sample_audio_path, sample_audio_text): + self.sample_audio_text = sample_audio_text + # 初始化模型 + llasa_3b = "HKUSTAudio/Llasa-3B" + xcodec2 = "HKUSTAudio/xcodec2" + + self.tokenizer = AutoTokenizer.from_pretrained(llasa_3b) + self.llasa_3b_model = AutoModelForCausalLM.from_pretrained( + llasa_3b, + trust_remote_code=True, + device_map="auto", + ) + self.llasa_3b_model.eval() + + self.xcodec_model = XCodec2Model.from_pretrained(xcodec2) + self.xcodec_model.eval().cuda() + + # 处理音频 + waveform, sample_rate = torchaudio.load(sample_audio_path) + if len(waveform[0]) / sample_rate > 15: + print("已将音频裁剪至前15秒。") + waveform = waveform[:, : sample_rate * 15] + + # 检查音频是否为立体声 + if waveform.size(0) > 1: + waveform_mono = torch.mean(waveform, dim=0, keepdim=True) + else: + waveform_mono = waveform + + self.prompt_wav = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=16000 + )(waveform_mono) + + # Encode the prompt wav + vq_code_prompt = self.xcodec_model.encode_code(input_waveform=self.prompt_wav) + vq_code_prompt = vq_code_prompt[0, 0, :] + self.speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt) + self.speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>") + + def ids_to_speech_tokens(self, speech_ids): + speech_tokens_str = [] + for speech_id in speech_ids: + speech_tokens_str.append(f"<|s_{speech_id}|>") + return speech_tokens_str + + def extract_speech_ids(self, speech_tokens_str): + speech_ids = [] + for token_str in speech_tokens_str: + if token_str.startswith("<|s_") and token_str.endswith("|>"): + num_str = token_str[4:-2] + num = int(num_str) + 
speech_ids.append(num) + else: + print(f"Unexpected token: {token_str}") + return speech_ids + + @torch.inference_mode() + def infer(self, target_text): + if len(target_text) == 0: + return None + elif len(target_text) > 300: + print("文本过长,请保持在300字符以内。") + target_text = target_text[:300] + + input_text = self.sample_audio_text + " " + target_text + + formatted_text = ( + f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>" + ) + + chat = [ + { + "role": "user", + "content": "Convert the text to speech:" + formatted_text, + }, + { + "role": "assistant", + "content": "<|SPEECH_GENERATION_START|>" + + "".join(self.speech_ids_prefix), + }, + ] + + input_ids = self.tokenizer.apply_chat_template( + chat, tokenize=True, return_tensors="pt", continue_final_message=True + ) + input_ids = input_ids.to("cuda") + + outputs = self.llasa_3b_model.generate( + input_ids, + max_length=2048, + eos_token_id=self.speech_end_id, + do_sample=True, + top_p=1, + temperature=0.8, + ) + generated_ids = outputs[0][input_ids.shape[1] - len(self.speech_ids_prefix): -1] + + speech_tokens = self.tokenizer.batch_decode( + generated_ids, skip_special_tokens=True + ) + + speech_tokens = self.extract_speech_ids(speech_tokens) + speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0) + + gen_wav = self.xcodec_model.decode_code(speech_tokens) + gen_wav = gen_wav[:, :, self.prompt_wav.shape[1]:] + + return (16000, gen_wav[0, 0, :].cpu().numpy()) + + +if __name__ == "__main__": + # 如果遇到问题,请尝试将参考音频转换为WAV或MP3格式,将其裁剪至15秒以内,并缩短提示文本。 + sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。" + sample_audio_path = os.path.join(os.path.dirname(__file__), "sample.wav") + + tts = TextToSpeech(sample_audio_path, sample_audio_text) + target_text = "晚上好啊,吃了吗您" + result = tts.infer(target_text) + sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), result[1], result[0]) + target_text = "我是老北京正黄旗!" + result = tts.infer(target_text) + sf.write(os.path.join(os.path.dirname(__file__), "output1.wav"), result[1], result[0]) diff --git a/weclone-audio/src/SparkTTS.py b/weclone-audio/src/SparkTTS.py new file mode 100644 index 0000000000000000000000000000000000000000..76b088affee89cb6c3d5bffeb805402e0892a61f --- /dev/null +++ b/weclone-audio/src/SparkTTS.py @@ -0,0 +1,223 @@ +import re +import torch +from typing import Tuple +from pathlib import Path +from transformers import AutoTokenizer, AutoModelForCausalLM +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "./Spark-TTS"))) +from sparktts.utils.file import load_config +from sparktts.models.audio_tokenizer import BiCodecTokenizer +from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP + + +class SparkTTS: + """ + Spark-TTS for text-to-speech generation. + """ + + def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")): + """ + Initializes the SparkTTS model with the provided configurations and device. + + Args: + model_dir (Path): Directory containing the model and config files. + device (torch.device): The device (CPU/GPU) to run the model on. 
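+
+        Example (illustrative path, assuming the model was downloaded as described in the README):
+            tts = SparkTTS(Path("pretrained_models/Spark-TTS-0.5B"), torch.device("cuda:0"))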
+ """ + self.device = device + self.model_dir = model_dir + self.configs = load_config(f"{model_dir}/config.yaml") + self.sample_rate = self.configs["sample_rate"] + self._initialize_inference() + + def _initialize_inference(self): + """Initializes the tokenizer, model, and audio tokenizer for inference.""" + self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM") + self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM") + self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device) + self.model.to(self.device) + + def process_prompt( + self, + text: str, + prompt_speech_path: Path, + prompt_text: str = None, + ) -> Tuple[str, torch.Tensor]: + """ + Process input for voice cloning. + + Args: + text (str): The text input to be converted to speech. + prompt_speech_path (Path): Path to the audio file used as a prompt. + prompt_text (str, optional): Transcript of the prompt audio. + + Return: + Tuple[str, torch.Tensor]: Input prompt; global tokens + """ + + global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize( + prompt_speech_path + ) + global_tokens = "".join( + [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()] + ) + + # Prepare the input tokens for the model + if prompt_text is not None: + semantic_tokens = "".join( + [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()] + ) + inputs = [ + TASK_TOKEN_MAP["tts"], + "<|start_content|>", + prompt_text, + text, + "<|end_content|>", + "<|start_global_token|>", + global_tokens, + "<|end_global_token|>", + "<|start_semantic_token|>", + semantic_tokens, + ] + else: + inputs = [ + TASK_TOKEN_MAP["tts"], + "<|start_content|>", + text, + "<|end_content|>", + "<|start_global_token|>", + global_tokens, + "<|end_global_token|>", + ] + + inputs = "".join(inputs) + + return inputs, global_token_ids + + def process_prompt_control( + self, + gender: str, + pitch: str, + speed: str, + text: str, + ): + """ + Process input for voice creation. + + Args: + gender (str): female | male. + pitch (str): very_low | low | moderate | high | very_high + speed (str): very_low | low | moderate | high | very_high + text (str): The text input to be converted to speech. + + Return: + str: Input prompt + """ + assert gender in GENDER_MAP.keys() + assert pitch in LEVELS_MAP.keys() + assert speed in LEVELS_MAP.keys() + + gender_id = GENDER_MAP[gender] + pitch_level_id = LEVELS_MAP[pitch] + speed_level_id = LEVELS_MAP[speed] + + pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>" + speed_label_tokens = f"<|speed_label_{speed_level_id}|>" + gender_tokens = f"<|gender_{gender_id}|>" + + attribte_tokens = "".join( + [gender_tokens, pitch_label_tokens, speed_label_tokens] + ) + + control_tts_inputs = [ + TASK_TOKEN_MAP["controllable_tts"], + "<|start_content|>", + text, + "<|end_content|>", + "<|start_style_label|>", + attribte_tokens, + "<|end_style_label|>", + ] + + return "".join(control_tts_inputs) + + @torch.no_grad() + def inference( + self, + text: str, + prompt_speech_path: Path = None, + prompt_text: str = None, + gender: str = None, + pitch: str = None, + speed: str = None, + temperature: float = 0.8, + top_k: float = 50, + top_p: float = 0.95, + ) -> torch.Tensor: + """ + Performs inference to generate speech from text, incorporating prompt audio and/or text. + + Args: + text (str): The text input to be converted to speech. + prompt_speech_path (Path): Path to the audio file used as a prompt. + prompt_text (str, optional): Transcript of the prompt audio. 
+ gender (str): female | male. + pitch (str): very_low | low | moderate | high | very_high + speed (str): very_low | low | moderate | high | very_high + temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8. + top_k (float, optional): Top-k sampling parameter. Default is 50. + top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95. + + Returns: + torch.Tensor: Generated waveform as a tensor. + """ + if gender is not None: + prompt = self.process_prompt_control(gender, pitch, speed, text) + + else: + prompt, global_token_ids = self.process_prompt( + text, prompt_speech_path, prompt_text + ) + model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device) + + # Generate speech using the model + generated_ids = self.model.generate( + **model_inputs, + max_new_tokens=3000, + do_sample=True, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + + # Trim the output tokens to remove the input tokens + generated_ids = [ + output_ids[len(input_ids):] + for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) + ] + + # Decode the generated tokens into text + predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + + # Extract semantic token IDs from the generated text + pred_semantic_ids = ( + torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)]) + .long() + .unsqueeze(0) + ) + + if gender is not None: + global_token_ids = ( + torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)]) + .long() + .unsqueeze(0) + .unsqueeze(0) + ) + + # Convert semantic tokens back to waveform + wav = self.audio_tokenizer.detokenize( + global_token_ids.to(self.device).squeeze(0), + pred_semantic_ids.to(self.device), + ) + + return wav diff --git a/weclone-audio/src/__init__.py b/weclone-audio/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/weclone-audio/src/get_sample_audio.py b/weclone-audio/src/get_sample_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..bb710b211115a70512163a84a68f60d43f9170bb --- /dev/null +++ b/weclone-audio/src/get_sample_audio.py @@ -0,0 +1,35 @@ +import os +import argparse +from pywxdump.db import MediaHandler + +def main(): + parser = argparse.ArgumentParser(description="Extract audio from WeChat database") + parser.add_argument("--db-path", type=str, required=True, + help="Path to WeChat database file") + parser.add_argument("--MsgSvrID", type=str, required=True, + help="Message server ID of the audio") + parser.add_argument("--save-path", type=str, + default=os.path.join(os.path.dirname(__file__), "sample.wav"), + help="Path to save the audio file (default: sample.wav in script directory)") + parser.add_argument("--rate", type=int, default=24000, + help="Sample rate for audio conversion (default: 24000)") + + args = parser.parse_args() + + config = { + "key": "test1", + "type": "sqlite", + "path": args.db_path, + } + + t1 = MediaHandler(config) + t1.get_audio( + MsgSvrID=args.MsgSvrID, + is_play=True, + is_wave=True, + save_path=args.save_path, + rate=args.rate, + ) + +if __name__ == "__main__": + main() diff --git a/weclone-audio/src/infer.py b/weclone-audio/src/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..29610b5becb1027d77781e14a6b7e02cb3048b0f --- /dev/null +++ b/weclone-audio/src/infer.py @@ -0,0 +1,17 @@ +import os +import soundfile as sf +import 
torch + +from SparkTTS import SparkTTS + +model = SparkTTS("weclone-audio/pretrained_models/Spark-TTS-0.5B", "cuda") + + +with torch.no_grad(): + wav = model.inference( + text="晚上好啊,小可爱们,该睡觉了哦", + prompt_speech_path=os.path.join(os.path.dirname(__file__), "sample.wav"), + prompt_text="对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。", + ) + sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), wav, samplerate=16000) + print("生成成功!") diff --git a/weclone-audio/src/sample.wav b/weclone-audio/src/sample.wav new file mode 100644 index 0000000000000000000000000000000000000000..89f432d08f7e1a9de2ad61b0ad8e57bd27195859 --- /dev/null +++ b/weclone-audio/src/sample.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:014954ddd00ec481f993ca65c904b8f3ff426df1be05ca260e2b03b3e892fc1b +size 412402 diff --git "a/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/.env.example" "b/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/.env.example" new file mode 100644 index 0000000000000000000000000000000000000000..cbcd9a26b0aadd2a61e28faf601b60464eaaa22d --- /dev/null +++ "b/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/.env.example" @@ -0,0 +1,14 @@ +API_KEY=your_api_key_here +PORT=5050 + +DEFAULT_VOICE=en-US-AvaNeural +DEFAULT_RESPONSE_FORMAT=mp3 +DEFAULT_SPEED=1.0 + +DEFAULT_LANGUAGE=en-US + +REQUIRE_API_KEY=True + +REMOVE_FILTER=False + +EXPAND_API=True \ No newline at end of file diff --git "a/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/handle_text.py" "b/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/handle_text.py" new file mode 100644 index 0000000000000000000000000000000000000000..dac452acdab3cdd7c5f1c5f199d84d9372210b4a --- /dev/null +++ "b/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/handle_text.py" @@ -0,0 +1,62 @@ +import re +import emoji + +def prepare_tts_input_with_context(text: str) -> str: + """ + Prepares text for a TTS API by cleaning Markdown and adding minimal contextual hints + for certain Markdown elements like headers. Preserves paragraph separation. + + Args: + text (str): The raw text containing Markdown or other formatting. + + Returns: + str: Cleaned text with contextual hints suitable for TTS input. 
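+
+    Example:
+        Inline code such as `foo()` becomes "code snippet: foo()", and Markdown links keep only their link text.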
+ """ + + # Remove emojis + text = emoji.replace_emoji(text, replace='') + + # Add context for headers + def header_replacer(match): + level = len(match.group(1)) # Number of '#' symbols + header_text = match.group(2).strip() + if level == 1: + return f"Title — {header_text}\n" + elif level == 2: + return f"Section — {header_text}\n" + else: + return f"Subsection — {header_text}\n" + + text = re.sub(r"^(#{1,6})\s+(.*)", header_replacer, text, flags=re.MULTILINE) + + # Announce links (currently commented out for potential future use) + # text = re.sub(r"\[([^\]]+)\]\((https?:\/\/[^\)]+)\)", r"\1 (link: \2)", text) + + # Remove links while keeping the link text + text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text) + + # Describe inline code + text = re.sub(r"`([^`]+)`", r"code snippet: \1", text) + + # Remove bold/italic symbols but keep the content + text = re.sub(r"(\*\*|__|\*|_)", '', text) + + # Remove code blocks (multi-line) with a description + text = re.sub(r"```([\s\S]+?)```", r"(code block omitted)", text) + + # Remove image syntax but add alt text if available + text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", r"Image: \1", text) + + # Remove HTML tags + text = re.sub(r"]+(>|$)", '', text) + + # Normalize line breaks + text = re.sub(r"\n{2,}", '\n\n', text) # Ensure consistent paragraph separation + + # Replace multiple spaces within lines + text = re.sub(r" {2,}", ' ', text) + + # Trim leading and trailing whitespace from the whole text + text = text.strip() + + return text diff --git "a/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/requirements.txt" "b/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/requirements.txt" new file mode 100644 index 0000000000000000000000000000000000000000..2f98e40ee97cc78dfd210f9de53703d84aae0a81 --- /dev/null +++ "b/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/requirements.txt" @@ -0,0 +1,5 @@ +flask +gevent +python-dotenv +edge-tts +emoji \ No newline at end of file diff --git "a/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/server.py" "b/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/server.py" new file mode 100644 index 0000000000000000000000000000000000000000..b107f2cf56e38f034b80f30367985a84bee0bfc3 --- /dev/null +++ "b/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/server.py" @@ -0,0 +1,167 @@ +# server.py + +from flask import Flask, request, send_file, jsonify +from gevent.pywsgi import WSGIServer +from dotenv import load_dotenv +import os + +from handle_text import prepare_tts_input_with_context +from tts_handler import generate_speech, get_models, get_voices +from utils import getenv_bool, require_api_key, AUDIO_FORMAT_MIME_TYPES + +app = Flask(__name__) +load_dotenv() + +API_KEY = os.getenv('API_KEY', 'your_api_key_here') +PORT = int(os.getenv('PORT', 5050)) + +DEFAULT_VOICE = os.getenv('DEFAULT_VOICE', 'en-US-AvaNeural') +DEFAULT_RESPONSE_FORMAT = os.getenv('DEFAULT_RESPONSE_FORMAT', 'mp3') +DEFAULT_SPEED = float(os.getenv('DEFAULT_SPEED', 1.0)) + +REMOVE_FILTER = getenv_bool('REMOVE_FILTER', False) +EXPAND_API = getenv_bool('EXPAND_API', True) + +# DEFAULT_MODEL = os.getenv('DEFAULT_MODEL', 'tts-1') + +@app.route('/v1/audio/speech', methods=['POST']) +@app.route('/audio/speech', methods=['POST']) # Add this line for the alias +@require_api_key +def text_to_speech(): + data = request.json + if not data or 'input' not in data: + return jsonify({"error": "Missing 'input' in request body"}), 400 + + text = data.get('input') + + if not REMOVE_FILTER: + text = 
prepare_tts_input_with_context(text) + + # model = data.get('model', DEFAULT_MODEL) + voice = data.get('voice', DEFAULT_VOICE) + + response_format = data.get('response_format', DEFAULT_RESPONSE_FORMAT) + speed = float(data.get('speed', DEFAULT_SPEED)) + + mime_type = AUDIO_FORMAT_MIME_TYPES.get(response_format, "audio/mpeg") + + # Generate the audio file in the specified format with speed adjustment + output_file_path = generate_speech(text, voice, response_format, speed) + + # Return the file with the correct MIME type + return send_file(output_file_path, mimetype=mime_type, as_attachment=True, download_name=f"speech.{response_format}") + +@app.route('/v1/models', methods=['GET', 'POST']) +@app.route('/models', methods=['GET', 'POST']) +@require_api_key +def list_models(): + return jsonify({"data": get_models()}) + +@app.route('/v1/voices', methods=['GET', 'POST']) +@app.route('/voices', methods=['GET', 'POST']) +@require_api_key +def list_voices(): + specific_language = None + + data = request.args if request.method == 'GET' else request.json + if data and ('language' in data or 'locale' in data): + specific_language = data.get('language') if 'language' in data else data.get('locale') + + return jsonify({"voices": get_voices(specific_language)}) + +@app.route('/v1/voices/all', methods=['GET', 'POST']) +@app.route('/voices/all', methods=['GET', 'POST']) +@require_api_key +def list_all_voices(): + return jsonify({"voices": get_voices('all')}) + +""" +Support for ElevenLabs and Azure AI Speech + (currently in beta) +""" + +# http://localhost:5050/elevenlabs/v1/text-to-speech +# http://localhost:5050/elevenlabs/v1/text-to-speech/en-US-AndrewNeural +@app.route('/elevenlabs/v1/text-to-speech/', methods=['POST']) +@require_api_key +def elevenlabs_tts(voice_id): + if not EXPAND_API: + return jsonify({"error": f"Endpoint not allowed"}), 500 + + # Parse the incoming JSON payload + try: + payload = request.json + if not payload or 'text' not in payload: + return jsonify({"error": "Missing 'text' in request body"}), 400 + except Exception as e: + return jsonify({"error": f"Invalid JSON payload: {str(e)}"}), 400 + + text = payload['text'] + + if not REMOVE_FILTER: + text = prepare_tts_input_with_context(text) + + voice = voice_id # ElevenLabs uses the voice_id in the URL + + # Use default settings for edge-tts + response_format = 'mp3' + speed = DEFAULT_SPEED # Optional customization via payload.get('speed', DEFAULT_SPEED) + + # Generate speech using edge-tts + try: + output_file_path = generate_speech(text, voice, response_format, speed) + except Exception as e: + return jsonify({"error": f"TTS generation failed: {str(e)}"}), 500 + + # Return the generated audio file + return send_file(output_file_path, mimetype="audio/mpeg", as_attachment=True, download_name="speech.mp3") + +# tts.speech.microsoft.com/cognitiveservices/v1 +# https://{region}.tts.speech.microsoft.com/cognitiveservices/v1 +# http://localhost:5050/azure/cognitiveservices/v1 +@app.route('/azure/cognitiveservices/v1', methods=['POST']) +@require_api_key +def azure_tts(): + if not EXPAND_API: + return jsonify({"error": f"Endpoint not allowed"}), 500 + + # Parse the SSML payload + try: + ssml_data = request.data.decode('utf-8') + if not ssml_data: + return jsonify({"error": "Missing SSML payload"}), 400 + + # Extract the text and voice from SSML + from xml.etree import ElementTree as ET + root = ET.fromstring(ssml_data) + text = root.find('.//{http://www.w3.org/2001/10/synthesis}voice').text + voice = 
root.find('.//{http://www.w3.org/2001/10/synthesis}voice').get('name') + except Exception as e: + return jsonify({"error": f"Invalid SSML payload: {str(e)}"}), 400 + + # Use default settings for edge-tts + response_format = 'mp3' + speed = DEFAULT_SPEED + + if not REMOVE_FILTER: + text = prepare_tts_input_with_context(text) + + # Generate speech using edge-tts + try: + output_file_path = generate_speech(text, voice, response_format, speed) + except Exception as e: + return jsonify({"error": f"TTS generation failed: {str(e)}"}), 500 + + # Return the generated audio file + return send_file(output_file_path, mimetype="audio/mpeg", as_attachment=True, download_name="speech.mp3") + +print(f" Edge TTS (Free Azure TTS) Replacement for OpenAI's TTS API") +print(f" ") +print(f" * Serving OpenAI Edge TTS") +print(f" * Server running on http://localhost:{PORT}") +print(f" * TTS Endpoint: http://localhost:{PORT}/v1/audio/speech") +print(f" ") + +if __name__ == '__main__': + http_server = WSGIServer(('0.0.0.0', PORT), app) + http_server.serve_forever() diff --git "a/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/tts_handler.py" "b/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/tts_handler.py" new file mode 100644 index 0000000000000000000000000000000000000000..bb2075c19f2a8a6140a53763d38466a509a06af9 --- /dev/null +++ "b/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/tts_handler.py" @@ -0,0 +1,133 @@ +import edge_tts +import asyncio +import tempfile +import subprocess +import os +from pathlib import Path + +# Language default (environment variable) +DEFAULT_LANGUAGE = os.getenv('DEFAULT_LANGUAGE', 'en-US') + +# OpenAI voice names mapped to edge-tts equivalents +voice_mapping = { + 'alloy': 'en-US-AvaNeural', + 'echo': 'en-US-AndrewNeural', + 'fable': 'en-GB-SoniaNeural', + 'onyx': 'en-US-EricNeural', + 'nova': 'en-US-SteffanNeural', + 'shimmer': 'en-US-EmmaNeural' +} + +def is_ffmpeg_installed(): + """Check if FFmpeg is installed and accessible.""" + try: + subprocess.run(['ffmpeg', '-version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + return False + +async def _generate_audio(text, voice, response_format, speed): + """Generate TTS audio and optionally convert to a different format.""" + # Determine if the voice is an OpenAI-compatible voice or a direct edge-tts voice + edge_tts_voice = voice_mapping.get(voice, voice) # Use mapping if in OpenAI names, otherwise use as-is + + # Generate the TTS output in mp3 format first + temp_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") + + # Convert speed to SSML rate format + try: + speed_rate = speed_to_rate(speed) # Convert speed value to "+X%" or "-X%" + except Exception as e: + print(f"Error converting speed: {e}. Defaulting to +0%.") + speed_rate = "+0%" + + # Generate the MP3 file + communicator = edge_tts.Communicate(text=text, voice=edge_tts_voice, rate=speed_rate) + await communicator.save(temp_output_file.name) + + # If the requested format is mp3, return the generated file directly + if response_format == "mp3": + return temp_output_file.name + + # Check if FFmpeg is installed + if not is_ffmpeg_installed(): + print("FFmpeg is not available. 
Returning unmodified mp3 file.") + return temp_output_file.name + + # Create a new temporary file for the converted output + converted_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=f".{response_format}") + + # Build the FFmpeg command + ffmpeg_command = [ + "ffmpeg", + "-i", temp_output_file.name, # Input file + "-c:a", { + "aac": "aac", + "mp3": "libmp3lame", + "wav": "pcm_s16le", + "opus": "libopus", + "flac": "flac" + }.get(response_format, "aac"), # Default to AAC if unknown + "-b:a", "192k" if response_format != "wav" else None, # Bitrate not needed for WAV + "-f", { + "aac": "mp4", # AAC in MP4 container + "mp3": "mp3", + "wav": "wav", + "opus": "ogg", + "flac": "flac" + }.get(response_format, response_format), # Default to matching format + "-y", # Overwrite without prompt + converted_output_file.name # Output file + ] + + try: + # Run FFmpeg command and ensure no errors occur + subprocess.run(ffmpeg_command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"FFmpeg error during audio conversion: {e}") + + # Clean up the original temporary file + Path(temp_output_file.name).unlink(missing_ok=True) + + return converted_output_file.name + +def generate_speech(text, voice, response_format, speed=1.0): + return asyncio.run(_generate_audio(text, voice, response_format, speed)) + +def get_models(): + return [ + {"id": "tts-1", "name": "Text-to-speech v1"}, + {"id": "tts-1-hd", "name": "Text-to-speech v1 HD"} + ] + +async def _get_voices(language=None): + # List all voices, filter by language if specified + all_voices = await edge_tts.list_voices() + language = language or DEFAULT_LANGUAGE # Use default if no language specified + filtered_voices = [ + {"name": v['ShortName'], "gender": v['Gender'], "language": v['Locale']} + for v in all_voices if language == 'all' or language is None or v['Locale'] == language + ] + return filtered_voices + +def get_voices(language=None): + return asyncio.run(_get_voices(language)) + +def speed_to_rate(speed: float) -> str: + """ + Converts a multiplicative speed value to the edge-tts "rate" format. + + Args: + speed (float): The multiplicative speed value (e.g., 1.5 for +50%, 0.5 for -50%). + + Returns: + str: The formatted "rate" string (e.g., "+50%" or "-50%"). 
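+
+    Examples:
+        speed_to_rate(1.5) -> "+50%"
+        speed_to_rate(0.75) -> "-25%"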
+ """ + if speed < 0 or speed > 2: + raise ValueError("Speed must be between 0 and 2 (inclusive).") + + # Convert speed to percentage change + percentage_change = (speed - 1) * 100 + + # Format with a leading "+" or "-" as required + return f"{percentage_change:+.0f}%" diff --git "a/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/utils.py" "b/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/utils.py" new file mode 100644 index 0000000000000000000000000000000000000000..7ad26b5a685b57999dc0af0acee4cd5bef0b5482 --- /dev/null +++ "b/weclone-audio/src/server\346\234\252\345\256\214\345\267\245/utils.py" @@ -0,0 +1,38 @@ +# utils.py + +from flask import request, jsonify +from functools import wraps +import os +from dotenv import load_dotenv + +load_dotenv() + +def getenv_bool(name: str, default: bool = False) -> bool: + return os.getenv(name, str(default)).lower() in ("yes", "y", "true", "1", "t") + +API_KEY = os.getenv('API_KEY', 'your_api_key_here') +REQUIRE_API_KEY = getenv_bool('REQUIRE_API_KEY', True) + +def require_api_key(f): + @wraps(f) + def decorated_function(*args, **kwargs): + if not REQUIRE_API_KEY: + return f(*args, **kwargs) + auth_header = request.headers.get('Authorization') + if not auth_header or not auth_header.startswith('Bearer '): + return jsonify({"error": "Missing or invalid API key"}), 401 + token = auth_header.split('Bearer ')[1] + if token != API_KEY: + return jsonify({"error": "Invalid API key"}), 401 + return f(*args, **kwargs) + return decorated_function + +# Mapping of audio format to MIME type +AUDIO_FORMAT_MIME_TYPES = { + "mp3": "audio/mpeg", + "opus": "audio/ogg", + "aac": "audio/aac", + "flac": "audio/flac", + "wav": "audio/wav", + "pcm": "audio/L16" +} diff --git a/weclone/__init__.py b/weclone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/weclone/cli.py b/weclone/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..fe470aed8d3676415f45f50e5fe5bbcabaa798ee --- /dev/null +++ b/weclone/cli.py @@ -0,0 +1,207 @@ +import click +import commentjson +from pathlib import Path +import os +import sys +import functools + +from weclone.utils.log import logger, capture_output +from weclone.utils.config import load_config + +cli_config: dict | None = None + +try: + import tomllib # type: ignore Python 3.11+ +except ImportError: + import tomli as tomllib + + +def clear_argv(func): + """ + 装饰器:在调用被装饰函数前,清理 sys.argv,只保留脚本名。调用后恢复原始 sys.argv。 + 用于防止参数被 Hugging Face HfArgumentParser 解析造成 ValueError。 + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + original_argv = sys.argv.copy() + sys.argv = [original_argv[0]] # 只保留脚本名 + try: + return func(*args, **kwargs) + finally: + sys.argv = original_argv # 恢复原始 sys.argv + + return wrapper + + +def apply_common_decorators(capture_output_enabled=False): + """ + A unified decorator for applications + """ + + def decorator(original_cmd_func): + @functools.wraps(original_cmd_func) + def new_runtime_wrapper(*args, **kwargs): + if cli_config and cli_config.get("full_log", False): + return capture_output(original_cmd_func)(*args, **kwargs) + else: + return original_cmd_func(*args, **kwargs) + + func_with_clear_argv = clear_argv(new_runtime_wrapper) + + return functools.wraps(original_cmd_func)(func_with_clear_argv) + + return decorator + + +@click.group() +def cli(): + """WeClone: 从聊天记录创造数字分身的一站式解决方案""" + _check_project_root() + _check_versions() + global cli_config + cli_config = 
load_config(arg_type="cli_args") + + +@cli.command("make-dataset", help="处理聊天记录CSV文件,生成问答对数据集。") +@apply_common_decorators() +def qa_generator(): + """处理聊天记录CSV文件,生成问答对数据集。""" + from weclone.data.qa_generator import DataProcessor + + processor = DataProcessor() + processor.main() + + +@cli.command("train-sft", help="使用准备好的数据集对模型进行微调。") +@apply_common_decorators() +def train_sft(): + """使用准备好的数据集对模型进行微调。""" + from weclone.train.train_sft import main as train_sft_main + + train_sft_main() + + +@cli.command("webchat-demo", help="启动 Web UI 与微调后的模型进行交互测试。") # 命令名修改为 web-demo +@apply_common_decorators() +def web_demo(): + """启动 Web UI 与微调后的模型进行交互测试。""" + from weclone.eval.web_demo import main as web_demo_main + + web_demo_main() + + +# TODO 添加评估功能 @cli.command("eval-model", help="使用从训练数据中划分出来的验证集评估。") +@apply_common_decorators() +def eval_model(): + """使用从训练数据中划分出来的验证集评估。""" + from weclone.eval.eval_model import main as evaluate_main + + evaluate_main() + + +@cli.command("test-model", help="使用常见聊天问题测试模型。") +@apply_common_decorators() +def test_model(): + """测试""" + from weclone.eval.test_model import main as test_main + + test_main() + + +@cli.command("server", help="启动API服务,提供模型推理接口。") +@apply_common_decorators() +def server(): + """启动API服务,提供模型推理接口。""" + from weclone.server.api_service import main as server_main + + server_main() + + +def _check_project_root(): + """检查当前目录是否为项目根目录,并验证项目名称。""" + project_root_marker = "pyproject.toml" + current_dir = Path(os.getcwd()) + pyproject_path = current_dir / project_root_marker + + if not pyproject_path.is_file(): + logger.error(f"未在当前目录找到 {project_root_marker} 文件。") + logger.error("请确保在WeClone项目根目录下运行此命令。") + sys.exit(1) + + try: + with open(pyproject_path, "rb") as f: + pyproject_data = tomllib.load(f) + project_name = pyproject_data.get("project", {}).get("name") + if project_name != "WeClone": + logger.error("请确保在正确的 WeClone 项目根目录下运行。") + sys.exit(1) + except tomllib.TOMLDecodeError as e: + logger.error(f"错误:无法解析 {pyproject_path} 文件: {e}") + sys.exit(1) + except Exception as e: + logger.error(f"读取或处理 {pyproject_path} 时发生意外错误: {e}") + sys.exit(1) + + +def _check_versions(): + """比较本地 settings.jsonc 版本和 pyproject.toml 中的配置文件指南版本""" + if tomllib is None: # Skip check if toml parser failed to import + return + + ROOT_DIR = Path(__file__).parent.parent + SETTINGS_PATH = ROOT_DIR / "settings.jsonc" + PYPROJECT_PATH = ROOT_DIR / "pyproject.toml" + + settings_version = None + config_guide_version = None + config_changelog = None + + if SETTINGS_PATH.exists(): + try: + with open(SETTINGS_PATH, "r", encoding="utf-8") as f: + settings_data = commentjson.load(f) + settings_version = settings_data.get("version") + except Exception as e: + logger.error(f"错误:无法读取或解析 {SETTINGS_PATH}: {e}") + logger.error("请确保 settings.jsonc 文件存在且格式正确。") + sys.exit(1) + else: + logger.error(f"错误:未找到配置文件 {SETTINGS_PATH}。") + logger.error("请确保 settings.jsonc 文件位于项目根目录。") + sys.exit(1) + + if PYPROJECT_PATH.exists(): + try: + with open(PYPROJECT_PATH, "rb") as f: # tomllib 需要二进制模式 + pyproject_data = tomllib.load(f) + weclone_tool_data = pyproject_data.get("tool", {}).get("weclone", {}) + config_guide_version = weclone_tool_data.get("config_version") + config_changelog = weclone_tool_data.get("config_changelog", "N/A") + except Exception as e: + logger.warning(f"警告:无法读取或解析 {PYPROJECT_PATH}: {e}。无法检查配置文件是否为最新。") + else: + logger.warning(f"警告:未找到文件 {PYPROJECT_PATH}。无法检查配置文件是否为最新。") + + if not settings_version: + logger.error(f"错误:在 {SETTINGS_PATH} 中未找到 'version' 字段。") + logger.error("请从 
settings.template.json 复制或更新您的 settings.jsonc 文件。") + sys.exit(1) + + if config_guide_version: + if settings_version != config_guide_version: + logger.warning( + f"警告:您的 settings.jsonc 文件版本 ({settings_version}) 与项目建议的配置版本 ({config_guide_version}) 不一致。" + ) + logger.warning("这可能导致意外行为或错误。请从 settings.template.json 复制或更新您的 settings.jsonc 文件。") + # TODO 根据版本号打印更新日志 + logger.warning(f"配置文件更新日志:\n{config_changelog}") + elif PYPROJECT_PATH.exists(): # 如果文件存在但未读到版本 + logger.warning( + f"警告:在 {PYPROJECT_PATH} 的 [tool.weclone] 下未找到 'config_version' 字段。" + "无法确认您的 settings.jsonc 是否为最新配置版本。" + ) + + +if __name__ == "__main__": + cli() diff --git a/weclone/core/inference/offline_infer.py b/weclone/core/inference/offline_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..6b133f3454429f05debd7cc8beab07e9c7902de2 --- /dev/null +++ b/weclone/core/inference/offline_infer.py @@ -0,0 +1,120 @@ +import json +from typing import List, Optional, Union + + +from llamafactory.data import get_dataset, get_template_and_fix_tokenizer +from llamafactory.extras.constants import IGNORE_INDEX +from llamafactory.extras.misc import get_device_count +from llamafactory.extras.packages import is_vllm_available +from llamafactory.hparams import get_infer_args +from llamafactory.model import load_tokenizer +from pydantic import BaseModel +from vllm.sampling_params import GuidedDecodingParams + + +from vllm import LLM, SamplingParams +from vllm.lora.request import LoRARequest + + +# 这里不需要写太好,transforms库后续更新自带vllm + + +def vllm_infer( + inputs: Union[str, List[str]], + model_name_or_path: str, + adapter_name_or_path: Optional[str] = None, + dataset: str = "alpaca_en_demo", + dataset_dir: str = "data", + template: str = "default", + cutoff_len: int = 2048, + max_samples: Optional[int] = None, + vllm_config: str = "{}", + save_name: str = "generated_predictions.jsonl", + temperature: float = 0.95, + top_p: float = 0.7, + top_k: int = 50, + guided_decoding_class: Optional[type[BaseModel]] = None, + bad_words: Optional[List[str]] = None, + logprobs: Optional[int] = None, + max_new_tokens: int = 1024, + repetition_penalty: float = 1.0, + skip_special_tokens: bool = True, + seed: Optional[int] = None, + pipeline_parallel_size: int = 1, + image_max_pixels: int = 768 * 768, + image_min_pixels: int = 32 * 32, +): + r"""Perform batch generation using vLLM engine, which supports tensor parallelism.""" + if pipeline_parallel_size > get_device_count(): + raise ValueError("Pipeline parallel size should be smaller than the number of gpus.") + + model_args, data_args, _, generating_args = get_infer_args( + dict( + model_name_or_path=model_name_or_path, + adapter_name_or_path=adapter_name_or_path, + dataset=dataset, + dataset_dir=dataset_dir, + template=template, + cutoff_len=cutoff_len, + max_samples=max_samples, + preprocessing_num_workers=16, + vllm_config=vllm_config, + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + ) + ) + + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + template_obj = get_template_and_fix_tokenizer(tokenizer, data_args) + template_obj.mm_plugin.expand_mm_tokens = False # for vllm generate + + if guided_decoding_class: + json_schema = guided_decoding_class.model_json_schema() + guided_decoding_params = GuidedDecodingParams(json=json_schema) + else: + guided_decoding_params = None + + sampling_params = SamplingParams( + repetition_penalty=generating_args.repetition_penalty 
or 1.0, # repetition_penalty must > 0 + temperature=generating_args.temperature, + top_p=generating_args.top_p or 1.0, # top_p must > 0 + top_k=generating_args.top_k or -1, # top_k must > 0 + stop_token_ids=template_obj.get_stop_token_ids(tokenizer), + max_tokens=generating_args.max_new_tokens, + skip_special_tokens=skip_special_tokens, + seed=seed, + guided_decoding=guided_decoding_params, + bad_words=bad_words, + ) + if model_args.adapter_name_or_path is not None: + lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0]) + else: + lora_request = None + + engine_args = { + "model": model_args.model_name_or_path, + "trust_remote_code": True, + "dtype": model_args.infer_dtype, + "max_model_len": cutoff_len + max_new_tokens, + # "tensor_parallel_size": 1, + # "pipeline_parallel_size": pipeline_parallel_size, + # "data_parallel_size": get_device_count(), // vllm0.8.5版本支持DP + "disable_log_stats": True, + "enable_lora": model_args.adapter_name_or_path is not None, + "enable_prefix_caching": True, # 是否启用前缀缓存 + "gpu_memory_utilization": 0.95, + # "quantization": "bitsandbytes", # 是否启用vllm的 bitsandbytes 的量化加载 + # "load_format": "bitsandbytes", + } + if template_obj.mm_plugin.__class__.__name__ != "BasePlugin": + engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2} + + if isinstance(model_args.vllm_config, dict): + engine_args.update(model_args.vllm_config) + + results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request) + return results diff --git a/weclone/core/inference/online_infer.py b/weclone/core/inference/online_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..972d82262771f11b3ee0fd26ad1aba8141e43e8a --- /dev/null +++ b/weclone/core/inference/online_infer.py @@ -0,0 +1,40 @@ +import json +import time +import requests +from openai import OpenAI + +class OnlineLLM: + def __init__(self, api_key: str, base_url: str,model_name: str,default_system: str): + self.api_key = api_key + self.base_url = base_url + self.model_name = model_name + self.default_system = default_system + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + + + def chat(self,prompt_text, + temperature: float = 0.7, + max_tokens: int = 1024, + top_p: float = 0.95, + stream: bool = False, + enable_thinking: bool = False): + messages = [ + {"role": "system", "content": self.default_system}, + {"role": "user", "content": prompt_text}, + ] + response = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + stream=stream, + temperature = temperature, + max_tokens=max_tokens, + top_p=top_p, + # enable_thinking=enable_thinking 适配Qwen3动态开启推理 + + ) + + return response + diff --git a/weclone/data/__init__.py b/weclone/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/weclone/data/chat_parsers/wechat_parser.py b/weclone/data/chat_parsers/wechat_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..a9862d665f2e7ed4be796c7aae26478c695fbf5d --- /dev/null +++ b/weclone/data/chat_parsers/wechat_parser.py @@ -0,0 +1,8 @@ +class WeChatParser: + def decrypt_wechat_image(self, encrypted_path, output_path): + """解密微信加密的图片文件""" + pass + + def parse_chat_records(self, db_path): + """解析聊天记录数据库""" + pass diff --git a/weclone/data/clean/__init__.py b/weclone/data/clean/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 
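A minimal usage sketch for the `OnlineLLM` wrapper above, which presumably backs the online data-cleaning option (`online_llm_clear`) in `settings.jsonc`; the endpoint, API key, and model name below are placeholders mirroring the template config, not values shipped with the repository:

```python
from weclone.core.inference.online_infer import OnlineLLM

# Placeholder values: copy base_url / llm_api_key / model_name from make_dataset_args in settings.jsonc.
llm = OnlineLLM(
    api_key="xxxxx",
    base_url="https://xxx/v1",
    model_name="DeepSeek-V3",  # the template suggests a larger model such as DeepSeek-V3
    default_system="请你扮演一名人类,不要说自己是人工智能",
)

# chat() forwards to client.chat.completions.create() and returns the raw OpenAI-style response object,
# so the reply text is read from choices[0].message.content.
response = llm.chat("晚上好啊", temperature=0.7, max_tokens=256)
print(response.choices[0].message.content)
```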
diff --git a/weclone/data/clean/get_score.py b/weclone/data/clean/get_score.py new file mode 100644 index 0000000000000000000000000000000000000000..3f3f2bcdae827cf39d1ecf611b952f4161e8e057 --- /dev/null +++ b/weclone/data/clean/get_score.py @@ -0,0 +1,64 @@ +import math + + +# TODO 未使用 +def adjust_score_tiered( + initial_score: int, probabilities: list[float], thresholds: list[float], downgrade_levels: list[int] +) -> int: + """ + 根据大模型给出评分时的概率,对原始评分进行分级置信度调整。 + + Args: + initial_score: 大模型给出的原始评分 (整数 1 到 5)。 + probabilities: 包含 5 个评分 (1 到 5) 概率的列表。 + 例如 [P(1), P(2), P(3), P(4), P(5)]。 + thresholds: 一个降序排列的概率阈值列表,定义置信度区间边界。 + 例如 [0.6, 0.3]。 + downgrade_levels: 与 thresholds 对应的降级幅度列表,长度比 thresholds 多 1。 + 定义了每个置信度区间的降级数。例如 [0, 1, 2]。 + + Returns: + 经过置信度调整后的最终评分 (整数 1 到 5)。 + + Raises: + ValueError: 如果输入参数不合法(例如概率列表长度不对,阈值未降序等)。 + """ + # --- 输入校验 --- + if not (1 <= initial_score <= 5): + raise ValueError("initial_score 必须在 1 到 5 之间。") + if len(probabilities) != 5: + raise ValueError("probabilities 列表必须包含 5 个元素。") + # 检查概率和是否接近 1 (允许小的浮点误差) + if not math.isclose(sum(probabilities), 1.0, abs_tol=1e-6): + print(f"警告: 概率之和 {sum(probabilities)} 不接近 1.0。请检查概率来源。") # 打印警告而非直接报错 + # raise ValueError("probabilities 中元素的和必须接近 1.0。") + if len(downgrade_levels) != len(thresholds) + 1: + raise ValueError("downgrade_levels 的长度必须比 thresholds 的长度多 1。") + if any(thresholds[i] < thresholds[i + 1] for i in range(len(thresholds) - 1)): + raise ValueError("thresholds 列表必须是降序排列的。") + if any(level < 0 for level in downgrade_levels): + raise ValueError("downgrade_levels 中的降级幅度不能为负数。") + + # --- 算法核心 --- + # 1. 获取选中分数的概率 + # 列表索引从0开始,所以评分 s 对应的索引是 s-1 + try: + p_chosen = probabilities[initial_score - 1] + except IndexError: + # 这个错误理论上不应发生,因为 initial_score 已校验在 1-5 之间 + raise ValueError(f"无法从 probabilities 列表获取索引 {initial_score - 1} 的值。") + + # 2. 确定降级幅度 + downgrade = downgrade_levels[-1] # 默认为最低置信度区间的降级幅度 + # 遍历阈值列表 (从高到低) + for i in range(len(thresholds)): + if p_chosen >= thresholds[i]: + downgrade = downgrade_levels[i] # 找到对应的置信度区间 + break # 停止遍历 + + # 3. 计算调整后的评分 + preliminary_score = initial_score - downgrade + adjusted_score = max(1, preliminary_score) # 确保分数不低于 1 + + # 4. 
返回结果 + return adjusted_score diff --git a/weclone/data/clean/strategies.py b/weclone/data/clean/strategies.py new file mode 100644 index 0000000000000000000000000000000000000000..d66e5f800391bc96edd9ccdea29bde428e110f4b --- /dev/null +++ b/weclone/data/clean/strategies.py @@ -0,0 +1,144 @@ +import json +import pandas as pd +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Union +from langchain_core.prompts import PromptTemplate +from weclone.data.models import QaPair, CutMessage, QaPairScore +from weclone.prompts.clean_data import CLEAN_PROMPT +import os +from weclone.utils.log import logger + + +@dataclass +class CleaningStrategy(ABC): + """数据清洗策略的抽象基类""" + + make_dataset_config: Dict + + @abstractmethod + def clean(self, data: Any) -> Any: + """ + 执行数据清洗操作。 + + Args: + data: 需要清洗的数据。 + + Returns: + 清洗后的数据。 + """ + pass + + +@dataclass +class LLMCleaningStrategy(CleaningStrategy): + """使用大模型进行数据清洗的策略""" + + + def judge(self, data: List[QaPair]) -> None: + """ + 调用llm打分,并将分数直接赋值给传入的QaPair。 + """ + from weclone.core.inference.offline_infer import vllm_infer + logger.info("开始使用llm对数据打分") + inputs = [] + prompt_template = PromptTemplate.from_template(CLEAN_PROMPT) + for qa in data: + inputs.append(prompt_template.invoke({"id": qa.id, "Q": qa.instruction, "A": qa.output}).text) # type: ignore + outputs = vllm_infer( + inputs, + self.make_dataset_config["model_name_or_path"], + template=self.make_dataset_config["template"], + temperature=0, + guided_decoding_class=QaPairScore, + repetition_penalty=1.2, + bad_words=[r"\n"], + ) + + parsed_scores: List[QaPairScore] = [] + for result in outputs: + try: + score_data = json.loads(result.outputs[0].text) + qa_score = QaPairScore(**score_data) + parsed_scores.append(qa_score) + except json.JSONDecodeError: + logger.error(f"Error decoding JSON: {result.outputs[0].text}") + + score_map = {score.id: score.score for score in parsed_scores} + for qa in data: + if qa.id in score_map: + qa.score = score_map[qa.id] + else: + logger.warning(f"Warning: Score not found for QaPair with id {qa.id}. 
Assigning default score.") + + scores = [qa.score for qa in data if qa.score is not None] + score_series = pd.Series(scores) + score_counts = score_series.value_counts().sort_index() + score_percentages = score_series.value_counts(normalize=True).sort_index() * 100 + pd.set_option("display.unicode.east_asian_width", True) # 尝试修正对齐问题 + distribution_df = pd.DataFrame( # 合并数量和百分比到一个 DataFrame 中以便打印 + { + "数量": score_counts, + "占比(%)": score_percentages.round(2), + } + ) + distribution_df.index.name = "分数" # 给第一列加上列名:分数 + printable_df_str = distribution_df.reset_index().to_string(index=False) + logger.success(f"llm打分分数分布情况:\n{printable_df_str}") + + def clean(self) -> str: + """ + 清洗 SFT 数据并返回清洗后的文件路径。 + 如果未启用清洗,则返回原始路径。 + """ + config = self.make_dataset_config + dataset_dir = config["dataset_dir"] + dataset_info_path = os.path.join(dataset_dir, "dataset_info.json") + + sft_json_path = os.path.join(dataset_dir, "sft-my.json") + output_json_path = os.path.join(dataset_dir, "sft-my-l.json") + accept_score = config.get("clean_dataset", {}).get("llm", {}).get("accept_score", 1) + + if not config.get("clean_dataset", {}).get("enable_clean"): + logger.info("未启用清洗功能") + self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json") + return sft_json_path + + try: + with open(sft_json_path, 'r', encoding='utf-8') as f: + data = json.load(f) + filtered_data = [item for item in data if item.get("score", 0) >= accept_score] + + with open(output_json_path, 'w', encoding='utf-8') as f: + json.dump(filtered_data, f, ensure_ascii=False, indent=4) + + logger.success(f"已筛出低于{accept_score}分的数据,共保留 {len(filtered_data)} 条数据") + self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my-l.json") + return output_json_path + + except Exception as e: + logger.error(f"清洗数据失败,使用原始数据: {str(e)}") + self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json") + return sft_json_path + + def _update_dataset_info_file(self, dataset_info_path: str, new_file_name: str): + """ + 修改 dataset_info.json 文件中的 file_name 字段 + """ + try: + with open(dataset_info_path, "r", encoding="utf-8") as f: + dataset_info = json.load(f) + + # 更新所有支持的数据集的 file_name + for key in ["wechat-sft", "wechat-sft-with-history"]: + if key in dataset_info: + dataset_info[key]["file_name"] = new_file_name + + # 写回文件 + with open(dataset_info_path, "w", encoding="utf-8") as f: + json.dump(dataset_info, f, indent=4, ensure_ascii=False) + + logger.info(f"已更新 dataset_info.json 中的 file_name 为 {new_file_name}") + + except Exception as e: + logger.warning(f"无法更新 dataset_info.json: {e}") diff --git a/weclone/data/clean/strategies_online.py b/weclone/data/clean/strategies_online.py new file mode 100644 index 0000000000000000000000000000000000000000..9fecd7f6003b486db2f158498406448718169d08 --- /dev/null +++ b/weclone/data/clean/strategies_online.py @@ -0,0 +1,155 @@ +import re +import json +import pandas as pd +from tqdm import tqdm +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List +from langchain_core.prompts import PromptTemplate +from weclone.data.models import QaPair, QaPairScore +from weclone.prompts.clean_data import CLEAN_PROMPT,ONLINE_LLM_CLEAN_PROMPT +from weclone.core.inference.online_infer import OnlineLLM +from weclone.utils.log import logger +import os + +@dataclass +class CleaningStrategy(ABC): + """数据清洗策略的抽象基类""" + + make_dataset_config: Dict + + @abstractmethod + def clean(self, data: Any) -> Any: + pass + +@dataclass +class 
OlineLLMCleaningStrategy(CleaningStrategy): + """使用大模型进行数据清洗的策略""" + + def judge(self, data: List[QaPair]) -> None: + logger.info("开始使用在线模型对数据打分") + + logger.info(f"使用模型 {self.make_dataset_config.get('model_name', '')}") + + client = OnlineLLM( + api_key = self.make_dataset_config.get("llm_api_key"), + base_url = self.make_dataset_config.get("base_url"), + model_name = self.make_dataset_config.get("model_name"), + default_system = self.make_dataset_config.get("default_system") + ) + prompt_template = PromptTemplate.from_template(ONLINE_LLM_CLEAN_PROMPT) + + parsed_scores = [] + clean_batch_size = int(self.make_dataset_config.get("clean_batch_size", 10)) + for i in tqdm(range(0, len(data), clean_batch_size), desc="在线模型评分进度"): + batch = data[i : i + clean_batch_size] + # 构造当前批次的 qa_list + qa_list = [ + {"id": qa.id, "Q": qa.instruction, "A": qa.output} + for qa in batch + ] + qa_list_json = json.dumps(qa_list, ensure_ascii=False) + # 填充模板 + prompt_text = prompt_template.invoke({ + "qa_list": qa_list_json + }).text + try: + response = client.chat(prompt_text) + result_text = response.choices[0].message.content + # print("大模型返回:",result_text) + # 如果有 ,只保留 之后的内容 + if "" in result_text: + result_text = result_text.split("", 1)[1] + # 去掉开头和结尾的 ```json 或 ``` 等代码块标记 + result_text = re.sub(r"^```json\s*|```$", "", result_text.strip(), flags=re.MULTILINE) + # 如果偶尔的几次解析失败就跳过 + try: + score_list = json.loads(result_text) + except json.JSONDecodeError as e: + logger.error(f"JSON 解析失败,跳过本批次: {e}\n内容:{result_text}") + continue + + for item in score_list: + parsed_scores.append(QaPairScore(**item)) + except Exception as e: + ids_in_batch = [qa["id"] for qa in qa_list] + logger.error(f"调用在线模型或解析结果失败,当前 batch QA ID 列表: {ids_in_batch},错误信息: {str(e)}") + + score_map = {score.id: score.score for score in parsed_scores} + for qa in data: + if qa.id in score_map: + qa.score = score_map[qa.id] + else: + logger.warning(f"未获取到QA ID {qa.id}的分数,默认赋值0") + qa.score = 0 + + # 统计分数分布,打印日志(和本地版本保持一致) + scores = [qa.score for qa in data if qa.score is not None] + score_series = pd.Series(scores) + score_counts = score_series.value_counts().sort_index() + score_percentages = score_series.value_counts(normalize=True).sort_index() * 100 + pd.set_option("display.unicode.east_asian_width", True) + distribution_df = pd.DataFrame({ + "数量": score_counts, + "占比(%)": score_percentages.round(2), + }) + distribution_df.index.name = "分数" + printable_df_str = distribution_df.reset_index().to_string(index=False) + logger.success(f"在线模型打分分数分布情况:\n{printable_df_str}") + + def clean(self) -> str: + """ + 清洗 SFT 数据并返回清洗后的文件路径。 + 如果未启用清洗,则返回原始路径。 + """ + config = self.make_dataset_config + dataset_dir = config["dataset_dir"] + dataset_info_path = os.path.join(dataset_dir, "dataset_info.json") + + sft_json_path = os.path.join(dataset_dir, "sft-my.json") + output_json_path = os.path.join(dataset_dir, "sft-my-l.json") + accept_score = config.get("clean_dataset", {}).get("llm", {}).get("accept_score", 1) + + if not config.get("clean_dataset", {}).get("enable_clean"): + logger.info("未启用清洗功能") + self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json") + return sft_json_path + + try: + with open(sft_json_path, 'r', encoding='utf-8') as f: + data = json.load(f) + filtered_data = [item for item in data if item.get("score", 0) >= accept_score] + + with open(output_json_path, 'w', encoding='utf-8') as f: + json.dump(filtered_data, f, ensure_ascii=False, indent=4) + + logger.success(f"已筛出低于{accept_score}分的数据,共保留 {len(filtered_data)} 
条数据") + self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my-l.json") + return output_json_path + + except Exception as e: + logger.error(f"清洗数据失败,使用原始数据: {str(e)}") + self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json") + return sft_json_path + + def _update_dataset_info_file(self, dataset_info_path: str, new_file_name: str): + """ + 修改 dataset_info.json 文件中的 file_name 字段 + """ + try: + with open(dataset_info_path, "r", encoding="utf-8") as f: + dataset_info = json.load(f) + + # 更新所有支持的数据集的 file_name + for key in ["wechat-sft", "wechat-sft-with-history"]: + if key in dataset_info: + dataset_info[key]["file_name"] = new_file_name + + # 写回文件 + with open(dataset_info_path, "w", encoding="utf-8") as f: + json.dump(dataset_info, f, indent=4, ensure_ascii=False) + + logger.info(f"已更新 dataset_info.json 中的 file_name 为 {new_file_name}") + + except Exception as e: + logger.warning(f"无法更新 dataset_info.json: {e}") diff --git a/weclone/data/models.py b/weclone/data/models.py new file mode 100644 index 0000000000000000000000000000000000000000..eb6eb0df3af31fa7d8587642854b8290b7ac02a0 --- /dev/null +++ b/weclone/data/models.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass +from pandas import Timestamp +from pydantic import BaseModel + + +@dataclass +class ChatMessage: + id: int + MsgSvrID: int + type_name: str + is_sender: int + talker: str + room_name: str + msg: str + src: str + CreateTime: Timestamp + + +@dataclass +class CutMessage: + is_sender: int + cut_type: str + CreateTime: Timestamp + + +@dataclass +class QaPair: + id: int + system: str + instruction: str + output: str + history: list[list[str]] + time: Timestamp + score: int + + +class QaPairScore(BaseModel): + id: int + score: int + + +skip_type_list = [ + "添加好友", + "推荐公众号", + "动画表情", + "位置", + "文件", + "位置共享", + "接龙", + "引用回复", + "视频号直播或直播回放", + "用户上传的GIF表情", + "文件(猜)", + "群公告", + "视频号直播或直播回放等", + "游戏相关", + "转账", + "赠送红包封面", + "语音通话", + "企业微信打招呼(猜)", + "企业微信添加好友(猜)", + "系统通知", + "消息撤回1", + "拍一拍", + "消息撤回5", + "消息撤回6", + "消息撤回33", + "消息撤回36", + "消息撤回57", + "邀请加群", + "未知-11000,0", +] +# 没处理的类型 +unprocessed_type_list = [] diff --git a/weclone/data/qa_generator.py b/weclone/data/qa_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..8575c01483a6eb261a93d250eafc4cd0f495049e --- /dev/null +++ b/weclone/data/qa_generator.py @@ -0,0 +1,506 @@ +import os +import sys +import subprocess +from typing import Dict, List, Union +import re + +import pandas as pd +import json +from pandas import Timestamp +from llamafactory.extras.packages import is_vllm_available + +from weclone.data.clean.strategies import LLMCleaningStrategy +from weclone.data.clean.strategies_online import OlineLLMCleaningStrategy +from weclone.utils.config import load_config +from weclone.utils.log import logger +from weclone.data.models import ChatMessage, CutMessage, skip_type_list, QaPair +from weclone.data.strategies import TimeWindowStrategy, LLMStrategy + + +class DataProcessor: + def __init__(self): + self.config = load_config(arg_type="make_dataset") + self.csv_folder = "./dataset/csv" + self.system_prompt = self.config["default_system"] + self.cut_type_list = [ + "图片", + "视频", + "合并转发的聊天记录", + "语音", + "(分享)音乐", + "(分享)卡片式链接", + "(分享)笔记", + "(分享)小程序", + "(分享)收藏夹", + "(分享)小说(猜)", + "(分享)视频号名片", + "(分享)视频号视频", + "粘贴的文本", # 无法解析的分享链接 + ] + + # blocked_words + config_blocked_words = self.config.get("blocked_words", []) + file_blocked_words = [] + try: + with open("./dataset/blocked_words.json", 
encoding="utf-8") as f: + file_blocked_words = json.load(f).get("blocked_words", []) + except (FileNotFoundError, json.JSONDecodeError): + pass + + self.blocked_words = list(set(config_blocked_words + file_blocked_words)) + # logger.info(f"聊天记录禁用词: {self.blocked_words}") + + if self.config["single_combine_strategy"] == "time_window": + self.single_combine_strategy = TimeWindowStrategy( + time_window=self.config["single_combine_time_window"] * 60, + is_single_chat=True, + ) + elif self.config["single_combine_strategy"] == "llm": + self.single_combine_strategy = LLMStrategy( + is_single_chat=True, + ) + + if self.config["qa_match_strategy"] == "time_window": + self.qa_match_strategy = TimeWindowStrategy( + time_window=self.config["qa_match_time_window"] * 60, + is_single_chat=False, + ) + elif self.config["qa_match_strategy"] == "llm": + self.qa_match_strategy = LLMStrategy(is_single_chat=False) + + clean_dataset_config = self.config.get("clean_dataset", {}) + enable_clean = clean_dataset_config.get("enable_clean", False) + + if enable_clean: + if self.config.get("prompt_with_history", False): + logger.warning("开启 prompt_with_history 不支持 clean_dataset 功能") + exit() + + if not is_vllm_available() and not self.config.get("online_llm_clear"): + logger.warning("vLLM 不可用,暂不清洗数据集。") + clean_dataset_config["enable_clean"] = False + + if self.config.get("clean_dataset", {}).get("enable_clean", False): + if self.config.get("clean_dataset", {}).get("clean_strategy", "llm") == "llm": + if self.config.get("online_llm_clear"): + self.clean_strategy = OlineLLMCleaningStrategy(make_dataset_config=self.config) + else: + self.clean_strategy = LLMCleaningStrategy(make_dataset_config=self.config) + self.c = self.config + + def main(self): + if not os.path.exists(self.csv_folder) or not os.listdir(self.csv_folder): + logger.error(f"错误:目录 '{self.csv_folder}' 不存在或为空,请检查路径并确保其中包含 CSV 聊天数据文件。") + return + + csv_files = self.get_csv_files() + logger.info(f"共发现 {len(csv_files)} 个 CSV 文件,开始处理") + message_list: List[ChatMessage] = [] + for csv_file in csv_files: + logger.debug(f"开始处理 CSV 文件: {csv_file}") + chat_messages = self.load_csv(csv_file) + message_list.extend(self.group_consecutive_messages(messages=chat_messages)) + # self.process_by_msgtype(chat_message) + logger.debug(f"处理完成: {csv_file},共加载 {len(chat_messages)} 条消息") + qa_res = self.match_qa(message_list) + if self.c["prompt_with_history"]: + qa_res = self.add_history_to_qa(qa_res) + else: + qa_res = [item for item in qa_res if isinstance(item, QaPair)] + + if self.c.get("clean_dataset", {}).get("enable_clean", False): + self.clean_strategy.judge(qa_res) + # qa_res = self.clean_strategy.clean(qa_res) + self.save_result(qa_res) + self._execute_length_cdf_script() + + logger.success(f"聊天记录处理成功,共{len(qa_res)}条,保存到 ./dataset/res_csv/sft/sft-my.json") + + def _execute_length_cdf_script(self): + """执行 length_cdf.py 脚本来计算cutoff_len。""" + try: + python_executable = sys.executable + # 脚本路径是相对于项目根目录的 + script_path = os.path.join("weclone", "utils", "length_cdf.py") + + command_parts = [ + python_executable, + script_path, + f'--model_name_or_path="{self.c["model_name_or_path"]}"', + f'--dataset="{self.c["dataset"]}"', + f'--dataset_dir="{self.c["dataset_dir"]}"', + f'--template="{self.c["template"]}"', + f"--interval={self.c['cutoff_len']}", + ] + + child_env = os.environ.copy() + child_env["CUDA_VISIBLE_DEVICES"] = "0" + child_env["LLAMAFACTORY_VERBOSITY"] = "ERROR" + + process = subprocess.Popen( + command_parts, + env=child_env, + stdout=None, # 使用 None 
表示使用父进程的标准输出(即终端) + stderr=None, # 使用 None 表示使用父进程的标准错误(即终端) + text=True, + bufsize=1, # 行缓冲 + ) + return_code = process.wait() + if return_code != 0: + logger.error(f"命令 '{' '.join(command_parts)}' 执行失败,返回码 {return_code}") + except FileNotFoundError: + # command_parts[0] 是 python_executable, command_parts[1] 是 script_path + logger.error(f"命令执行失败: 找不到可执行文件 '{command_parts[0]}' 或脚本 '{command_parts[1]}'") + except KeyError as e: + logger.error(f"执行 length_cdf.py 脚本失败:配置项缺失 {str(e)}") + except Exception as e: + logger.error(f"执行 length_cdf.py 脚本时发生未知错误: {str(e)}") + + def get_csv_files(self): + """遍历文件夹获取所有CSV文件路径,并按文件名中的起始序号排序""" + + csv_files = [] + for chat_obj_folder in os.listdir(self.csv_folder): + chat_obj_folder_path = os.path.join(self.csv_folder, chat_obj_folder) + for csvfile in os.listdir(chat_obj_folder_path): + if not csvfile.endswith(".csv"): + continue + csvfile_path = os.path.join(chat_obj_folder_path, csvfile) + csv_files.append(csvfile_path) + # 提取文件名中的起始数字,比如 wxid_..._0_5000.csv → 0 + pattern = re.compile(r"_(\d+)_\d+\.csv$") + + def extract_start(fp: str) -> int: + name = os.path.basename(fp) + m = pattern.search(name) + return int(m.group(1)) if m else 0 + + # 按起始数字升序排序 + csv_files.sort(key=extract_start) + return csv_files + + def match_qa(self, messages: List[ChatMessage]) -> List[Union[QaPair, CutMessage]]: + """ + 匹配问答对 + + Args: + messages: 消息列表 + + Returns: + List[Union[QaPair, CutMessage]]: 包含指令和输出的问答对列表 + """ + # 状态定义 + WAITING_INSTRUCTION = "waiting_instruction" # 等待指令 + WAITING_RESPONSE = "waiting_response" # 等待回复 + + current_state = WAITING_INSTRUCTION + qa_res: List[Union[QaPair, CutMessage]] = [] + last_message = None + current_instruction = None + qa_id_counter = 0 + + for msg in messages: + if isinstance(msg, CutMessage): + current_state = WAITING_INSTRUCTION + current_instruction = None + last_message = None + if self.c["prompt_with_history"]: + qa_res.append(msg) + continue + + if current_state == WAITING_INSTRUCTION: + if msg.is_sender == 0: # 收到对方消息 + current_instruction = msg.msg + last_message = msg + current_state = WAITING_RESPONSE + + elif current_state == WAITING_RESPONSE: + if msg.is_sender == 0: # 收到对方消息 + current_instruction = msg.msg + last_message = msg + # 状态保持不变 + else: # 自己的回复 使用策略判断是否属于同一对话 + if last_message and self.qa_match_strategy.is_same_conversation([last_message], msg): + assert current_instruction is not None, ( + "current_instruction should not be None when creating a QA pair" + ) + qa_pair = QaPair( + id=qa_id_counter, + system=self.system_prompt, + instruction=current_instruction, + output=msg.msg, + history=[], # No history in this context yet + time=msg.CreateTime, # Use the response message time + score=0, # Default score + ) + qa_res.append(qa_pair) + qa_id_counter += 1 # 增加计数器 + else: + if self.c["prompt_with_history"]: + qa_res.append( + CutMessage( + is_sender=msg.is_sender, + cut_type=msg.type_name, + CreateTime=msg.CreateTime, + ) + ) + # 无论是否匹配,都重置状态 + current_state = WAITING_INSTRUCTION + current_instruction = None + last_message = None + + return qa_res + + # TODO: need review + def add_history_to_qa(self, qa_res: List[Union[QaPair, CutMessage]]) -> List[QaPair]: + """ + Adds conversation history to QaPair objects. + + Args: + qa_res: A list containing QaPair and CutMessage objects. + + Returns: + A list of QaPair objects with history populated. 
+ """ + qa_res_with_history: List[QaPair] = [] + current_history: List[List[str]] = [] + last_timestamp: Timestamp = None # type: ignore + + for item in qa_res: + if isinstance(item, CutMessage): + if current_history: + instruction = current_history[-1][0] + output = current_history[-1][1] + history = current_history[:-1] + qa_pair_with_history = QaPair( + id=-1, + system=self.system_prompt, + instruction=instruction, + output=output, + history=history, + time=last_timestamp, + score=0, + ) + qa_res_with_history.append(qa_pair_with_history) + current_history = [] + last_timestamp = None # type: ignore + elif isinstance(item, QaPair): + current_history.append([item.instruction, item.output]) + last_timestamp = item.time + + if current_history: + instruction = current_history[-1][0] + output = current_history[-1][1] + history = current_history[:-1] + # Ensure last_timestamp is not None before assignment + final_timestamp_end = last_timestamp + assert final_timestamp_end is not None, "Timestamp cannot be None for the final QaPair" + qa_pair_with_history = QaPair( + id=-1, + system=self.system_prompt, + instruction=instruction, + output=output, + history=history, + time=final_timestamp_end, + score=0, + ) + qa_res_with_history.append(qa_pair_with_history) + + return qa_res_with_history + + def group_consecutive_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]: + """ + 将同一个人连续发送的多条消息组合成一条消息,遇到cut_type添加cut + + Args: + messages: 消息列表 + + Returns: + List[ChatMessage]: 组合后的消息列表 + """ + if not messages: + return [] + + def _combine_text(messages: List[ChatMessage]) -> ChatMessage: + """ + 合并多条消息为一条 + + Args: + messages: 要合并的消息列表 + + Returns: + ChatMessage: 合并后的消息 + """ + base_msg = messages[0] + combined_content = messages[0].msg + + for i in messages[1:]: + content = i.msg + if not content: + continue + + if combined_content and combined_content[-1] not in ["。", "!", "?", "…", ",", "."]: + combined_content += "," + + combined_content += content + if len(combined_content) > self.c["combine_msg_max_length"]: + logger.warning( + f"组合后消息长度超过{self.c['combine_msg_max_length']}将截断:\n {combined_content[:50]}" + ) + combined_content = combined_content[: self.c["combine_msg_max_length"]] + + combined_message = ChatMessage( + id=base_msg.id, + MsgSvrID=base_msg.MsgSvrID, + type_name=base_msg.type_name, + is_sender=base_msg.is_sender, + talker=base_msg.talker, + room_name=base_msg.room_name, + msg=combined_content, + src=base_msg.src, + CreateTime=messages[-1].CreateTime, # 使用最后一条消息的时间 + ) + + return combined_message + + def _create_cut_message(message: ChatMessage) -> CutMessage: + return CutMessage( + is_sender=message.is_sender, + cut_type=message.type_name, + CreateTime=message.CreateTime, + ) + + def _combine_current_group(group): + """ + 处理当前消息组并添加到grouped_messages + + Args: + group: 当前消息组 + """ + if len(group) > 1: + combined_msg = _combine_text(group) + grouped_messages.append(combined_msg) + else: + grouped_messages.append(group[0]) + + grouped_messages = [] + current_group = [] + + for _, current_msg in enumerate(messages): + if current_msg.type_name in self.cut_type_list: + if current_group: + # 当前组有消息,合并当前组,并添加一条cut + _combine_current_group(current_group) + current_group = [] + + cut_msg = _create_cut_message(current_msg) + grouped_messages.append(cut_msg) + else: + # 当前组没消息,检查上一个组 + if grouped_messages: + if not isinstance(grouped_messages[-1], CutMessage): + cut_msg = _create_cut_message(current_msg) + grouped_messages.append(cut_msg) + # 如果上一个组没消息或最后一条是CutMessage,直接continue + 
continue + + if not current_group: + current_group = [current_msg] + continue + + last_msg = current_group[-1] + + # 判断是否是同一个人的连续消息 + if ( + current_msg.is_sender == last_msg.is_sender + and current_msg.talker == last_msg.talker + and self.single_combine_strategy.is_same_conversation([last_msg], current_msg) + ): + current_group.append(current_msg) + else: + # 不是同一个人的消息,处理当前组并开始新组 + _combine_current_group(current_group) + # 开始新组 + current_group = [current_msg] + + # 处理最后一组消息 + if current_group: + _combine_current_group(current_group) + + return grouped_messages + + def process_by_msgtype(self, chat_message: ChatMessage): + if chat_message.type_name == "文本": + self.process_text(chat_message) + # elif chat_message.type_name == "图片": + # self.process_image(chat_message) + + def load_csv(self, file_path) -> List[ChatMessage]: + """ + 做整体第一次预处理,过滤不符合条件的行 + """ + df = pd.read_csv(file_path, encoding="utf-8", dtype={"msg": str}) + + df = df[~df["type_name"].isin(values=skip_type_list)] + + # 如果type_name为文本 并且msg 包含 手机号、身份证号、邮箱、网址则删除这行 + for i in df.index: + if df.loc[i, "type_name"] == "文本": + msg_str = str(df.loc[i, "msg"]) + if ( + re.search(r"1\d{10}", msg_str) + or re.search(r"\d{18}", msg_str) + or re.search(r"\w+@\w+", msg_str) + or "http" in msg_str + or r"\\xa0" in msg_str + or r"\\u" in msg_str + ): + df = df.drop(index=i) + continue + for blocked_word in self.blocked_words: + if blocked_word in msg_str: + df = df.drop(index=i) + break + else: + df.loc[i, "msg"] = "" + + df = df.dropna(how="all") + # 时间格式 2021-07-07 10:27:23 + # 遍历行 相同is_sender的行合并msg()遇到不同is_sender就重新开始 + df["CreateTime"] = pd.to_datetime(df["CreateTime"]) + + return [ChatMessage(*row) for row in df.values] + + def process_text(self, chat_message: ChatMessage): + pass + + def save_result(self, qa_res: List[QaPair]): + """ + Saves the list of QaPair objects to a JSON file after converting them to dictionaries. + + Args: + qa_res: A list of QaPair objects. 
+ """ + processed_qa_res = [] + for idx, item in enumerate(qa_res): + item_dict = { + "id": idx, + "system": item.system, + "instruction": item.instruction, + "output": item.output, + "history": item.history, + "time": item.time.isoformat() if item.time else None, + "score": item.score, + } + processed_qa_res.append(item_dict) + + output_path = "./dataset/res_csv/sft/sft-my.json" + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(processed_qa_res, f, ensure_ascii=False, indent=4) + logger.success(f"聊天记录处理成功,共{len(qa_res)}条,保存到 {output_path}") + + +if __name__ == "__main__": + processor = DataProcessor() + processor.main() diff --git a/weclone/data/strategies.py b/weclone/data/strategies.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b313755b9b2cd32a57d1d20179c232b9c98fbe --- /dev/null +++ b/weclone/data/strategies.py @@ -0,0 +1,60 @@ +from dataclasses import dataclass +from typing import List +from .models import ChatMessage +from abc import ABC, abstractmethod + + +@dataclass +class ConversationStrategy(ABC): + """对话策略的抽象基类""" + + is_single_chat: bool + + @abstractmethod + def is_same_conversation( + self, history_msg: List[ChatMessage], current_msg: ChatMessage + ) -> bool: + """判断两条消息是否属于同一个对话""" + pass + + +@dataclass +class TimeWindowStrategy(ConversationStrategy): + """基于时间窗口的判断策略""" + + time_window: int # 时间窗口(分钟) + + def is_same_conversation( + self, history_msg: List[ChatMessage], current_msg: ChatMessage + ) -> bool: + time_diff = abs( + (current_msg.CreateTime - history_msg[-1].CreateTime) + ).total_seconds() + return time_diff <= self.time_window + + +@dataclass +class LLMStrategy(ConversationStrategy): + """基于大模型判断策略""" + + def is_same_conversation( + self, history_msg: List[ChatMessage], current_msg: ChatMessage + ) -> bool: + # 修复user_id错误,使用talker字段代替user_id + return current_msg.talker == history_msg[-1].talker if history_msg else False + + +@dataclass +class CompositeStrategy(ConversationStrategy): + """组合多个策略的复合策略""" + + strategies: List[ConversationStrategy] + require_all: bool = True # True表示所有策略都满足,False表示任一策略满足即可 + + def is_same_conversation( + self, history_msg: List[ChatMessage], current_msg: ChatMessage + ) -> bool: + results = [ + s.is_same_conversation(history_msg, current_msg) for s in self.strategies + ] + return all(results) if self.require_all else any(results) diff --git a/weclone/eval/__init__.py b/weclone/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/weclone/eval/cli_demo.py b/weclone/eval/cli_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..0142278d6d825c31cb0f4576b69b2b0f605b910e --- /dev/null +++ b/weclone/eval/cli_demo.py @@ -0,0 +1,48 @@ +from llamafactory.chat import ChatModel +from llamafactory.extras.misc import torch_gc + + +def main(): + try: + import platform + + if platform.system() != "Windows": + import readline # noqa: F401 + except ImportError: + print("Install `readline` for a better experience.") + + chat_model = ChatModel() + messages = [] + print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") + + while True: + try: + query = input("\nUser: ") + except UnicodeDecodeError: + print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.") + continue + except Exception: + raise + + if query.strip() == "exit": + break + + if query.strip() 
== "clear": + messages = [] + torch_gc() + print("History has been removed.") + continue + + messages.append({"role": "user", "content": query}) + print("Assistant: ", end="", flush=True) + + response = "" + for new_text in chat_model.stream_chat(messages): + print(new_text, end="", flush=True) + response += new_text + print() + messages.append({"role": "assistant", "content": response}) + + +if __name__ == "__main__": + main() diff --git a/weclone/eval/eval_model.py b/weclone/eval/eval_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0cfd81137daee58a1600a36f50991682e953c5a9 --- /dev/null +++ b/weclone/eval/eval_model.py @@ -0,0 +1,10 @@ +from llamafactory.eval.evaluator import Evaluator + + +def main(): + evaluator = Evaluator() + evaluator.eval() + + +if __name__ == "__main__": + main() diff --git a/weclone/eval/test_model.py b/weclone/eval/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6517634425c0c301d2b4c83e1caad51df1109f47 --- /dev/null +++ b/weclone/eval/test_model.py @@ -0,0 +1,70 @@ +import json +import openai +from openai import OpenAI # 导入 OpenAI 类 + +from tqdm import tqdm +from typing import List, Dict, cast # 导入 cast +from openai.types.chat import ChatCompletionMessageParam # 导入消息参数类型 + +from weclone.utils.config import load_config + +config = load_config("web_demo") + +config = { + "default_prompt": config["default_system"], + "model": "gpt-3.5-turbo", + "history_len": 15, +} + +config = type("Config", (object,), config)() + +# 初始化 OpenAI 客户端 +client = OpenAI( + api_key="""sk-test""", + base_url="http://127.0.0.1:8005/v1" +) + + +def handler_text(content: str, history: list, config): + messages = [{"role": "system", "content": f"{config.default_prompt}"}] + for item in history: + messages.append(item) + messages.append({"role": "user", "content": content}) + history.append({"role": "user", "content": content}) + try: + # 使用新的 API 调用方式 + # 将 messages 转换为正确的类型 + typed_messages = cast(List[ChatCompletionMessageParam], messages) + response = client.chat.completions.create( + model=config.model, + messages=typed_messages, # 传递转换后的列表 + max_tokens=50 + ) + except openai.APIError as e: + history.pop() + return "AI接口出错,请重试\n" + str(e) + + resp = str(response.choices[0].message.content) # type: ignore + resp = resp.replace("\n ", "") + history.append({"role": "assistant", "content": resp}) + return resp + + +def main(): + test_list = json.loads(open("dataset/test_data.json", "r", encoding="utf-8").read())["questions"] + res = [] + for questions in tqdm(test_list, desc=" Testing..."): + history = [] + for q in questions: + handler_text(q, history=history, config=config) + res.append(history) + + res_file = open("test_result-my.txt", "w") + for r in res: + for i in r: + res_file.write(i["content"] + "\n") + res_file.write("\n") + + +if __name__ == "__main__": + main() diff --git a/weclone/eval/web_demo.py b/weclone/eval/web_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..de2eaba903ebd274d7774df44eb9ec191b343808 --- /dev/null +++ b/weclone/eval/web_demo.py @@ -0,0 +1,13 @@ +from llamafactory.webui.interface import create_web_demo +from weclone.utils.config import load_config + + +def main(): + config = load_config("web_demo") + demo = create_web_demo() + demo.queue() + demo.launch(server_name="0.0.0.0", share=True, inbrowser=True) + + +if __name__ == "__main__": + main() diff --git a/weclone/prompts/__init__.py b/weclone/prompts/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/weclone/prompts/clean_data.py b/weclone/prompts/clean_data.py new file mode 100644 index 0000000000000000000000000000000000000000..048d1579d993fa8bf77b661c4ec897f5e0b269e5 --- /dev/null +++ b/weclone/prompts/clean_data.py @@ -0,0 +1,80 @@ +CLEAN_PROMPT = """ +# 角色 +你是一个数据质量评估员。 + +# 任务 +你的任务是评估下面提供的【回答 A】相对于【问题/上下文 Q】的**逻辑性**和**相关性**。目标是识别并帮助过滤掉那些回答与问题**明显不匹配**、**逻辑严重混乱**的数据对。请根据以下核心评估点给出一个1到5的整数分数,并将该分数与原始 `id` 一起输出。 + +**重要考量:** +1. **简短回答的有效性:** 请注意,诸如“好的”、“是的”、“收到”、“嗯”、“知道了”等简短的肯定、确认或应答,在合适的语境下是完全**有逻辑且相关的**。**不要仅仅因为回答简短就将其评为低分。** 只有当这类简短回答与【问题/上下文 Q】**明显不符**时,才应考虑低分。 +2. **处理错别字和自我纠正:** 聊天记录中可能包含常见的打字错误(错别字)或用户先打错字随后又自行纠正的情况(例如,发送“我想去1楼”紧接着又发送“*2楼”进行更正)。在评估时,请**聚焦于用户想要表达的最终意图和信息的核心内容**,而**不应仅仅因为存在错别字或纠正过程就判定为低质量**。。 + + +# 核心评估点 (请在心中衡量) +1. **相关性 (Relevance):** 【回答 A】是否直接回应或恰当地衔接了【问题/上下文 Q】?它是在回答问题,还是完全跑题了?只有当【回答 A】与【问题/上下文 Q】**明显矛盾**、**完全不着边际**(即使考虑上下文也无法合理化),或简短回答**明显不适用于**该【问题/上下文 Q】时,才给予低分。 +2. **逻辑性 (Coherence):** 【回答 A】本身是否符合基本的逻辑?结合【问题/上下文 Q】来看,这个问答对是否构成了一个符合逻辑的交流片段?是否存在明显的矛盾、混乱的内容?只有当【回答 A】**自身逻辑混乱**、**与Q存在无法解释的矛盾**时,才给予低分。 + +# 评分标准 (1-5分) +* **1分 (极差):** 完全不相关;逻辑严重混乱/矛盾。 +* **2分 (差):** 相关性很低;存在明显的逻辑问题或不连贯。 +* **3分 (中等):** 相关性一般(可能部分跑题或回应不充分);逻辑上勉强说得通但不够流畅或有瑕疵。 +* **4分 (良好):** 相关性好,回答了问题或恰当衔接;逻辑清晰。 +* **5分 (优秀):** 相关性强,回应精准;逻辑严谨流畅。 + +# 输入数据 +```json +{{ + "id": "{id}", + "Q": "{Q}", + "A": "{A}" +}} + +# 输出要求 +请严格按照以下 JSON 格式输出,包含原始的 id 和你给出的1到5的整数评分 score,不要包含任何其他文字、解释或标签。 +{{ + "id": "<这里填入输入数据中的id值>", + "score": <这里填入1到5的整数评分> +}} +""" + +ONLINE_LLM_CLEAN_PROMPT = """ +# 角色 +你是一个数据质量评估员。 + +# 任务 +你的任务是评估下面提供的【回答 A】相对于【问题/上下文 Q】的**逻辑性**和**相关性**。目标是识别并帮助过滤掉那些回答与问题**明显不匹配**、**逻辑严重混乱**的数据对。请根据以下核心评估点给出一个1到5的整数分数,并将该分数与原始 `id` 一起输出。 + +**重要考量:** +1. **简短回答的有效性:** 请注意,诸如“好的”、“是的”、“收到”、“嗯”、“知道了”等简短的肯定、确认或应答,在合适的语境下是完全**有逻辑且相关的**。**不要仅仅因为回答简短就将其评为低分。** 只有当这类简短回答与【问题/上下文 Q】**明显不符**时,才应考虑低分。 +2. **处理错别字和自我纠正:** 聊天记录中可能包含常见的打字错误(错别字)或用户先打错字随后又自行纠正的情况(例如,发送“我想去1楼”紧接着又发送“*2楼”进行更正)。在评估时,请**聚焦于用户想要表达的最终意图和信息的核心内容**,而**不应仅仅因为存在错别字或纠正过程就判定为低质量**。。 + + +# 核心评估点 (请在心中衡量) +1. **相关性 (Relevance):** 【回答 A】是否直接回应或恰当地衔接了【问题/上下文 Q】?它是在回答问题,还是完全跑题了?只有当【回答 A】与【问题/上下文 Q】**明显矛盾**、**完全不着边际**(即使考虑上下文也无法合理化),或简短回答**明显不适用于**该【问题/上下文 Q】时,才给予低分。 +2. **逻辑性 (Coherence):** 【回答 A】本身是否符合基本的逻辑?结合【问题/上下文 Q】来看,这个问答对是否构成了一个符合逻辑的交流片段?是否存在明显的矛盾、混乱的内容?只有当【回答 A】**自身逻辑混乱**、**与Q存在无法解释的矛盾**时,才给予低分。 + +# 评分标准 (1-5分) +* **1分 (极差):** 完全不相关;逻辑严重混乱/矛盾。 +* **2分 (差):** 相关性很低;存在明显的逻辑问题或不连贯。 +* **3分 (中等):** 相关性一般(可能部分跑题或回应不充分);逻辑上勉强说得通但不够流畅或有瑕疵。 +* **4分 (良好):** 相关性好,回答了问题或恰当衔接;逻辑清晰。 +* **5分 (优秀):** 相关性强,回应精准;逻辑严谨流畅。 + +# 输入数据 +```json +{qa_list} + +# 输出要求 +请严格按照以下 JSON 格式输出,包含原始的 id 和你给出的1到5的整数评分 score,不要包含任何其他文字、解释或标签! 
+[ + {{ + "id": "<这里填入第1条输入数据中的id值>", + "score": <1-5的整数评分> + }}, + {{ + "id": "<这里填入第2条输入数据中的id值>", + "score": <1-5的整数评分> + }} + … +] +""" \ No newline at end of file diff --git a/weclone/server/__init__.py b/weclone/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/weclone/server/api_service.py b/weclone/server/api_service.py new file mode 100644 index 0000000000000000000000000000000000000000..addefc6a8538a868fa64e3e79eb97d3e4991c74c --- /dev/null +++ b/weclone/server/api_service.py @@ -0,0 +1,18 @@ +import os +import uvicorn +from llamafactory.chat import ChatModel +from llamafactory.api.app import create_app +from weclone.utils.config import load_config + + + +def main(): + config = load_config("api_service") + chat_model = ChatModel(config) + app = create_app(chat_model) + print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8005))) + uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8005)), workers=1) + + +if __name__ == "__main__": + main() diff --git a/weclone/train/__init__.py b/weclone/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/weclone/train/export_model.py b/weclone/train/export_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5a24eeef5e0160f0b4fe27a3dcc4bb4ec384fc94 --- /dev/null +++ b/weclone/train/export_model.py @@ -0,0 +1,9 @@ +from llamafactory.train.tuner import export_model + + +def main(): + export_model() + + +if __name__ == "__main__": + main() diff --git a/weclone/train/train_pt.py b/weclone/train/train_pt.py new file mode 100644 index 0000000000000000000000000000000000000000..12547c83e3e5987464b9533d73e60b67a066d25b --- /dev/null +++ b/weclone/train/train_pt.py @@ -0,0 +1,5 @@ +from llamafactory.train.tuner import run_exp +from weclone.utils.config import load_config + +config = load_config("train_pt") +run_exp(config) diff --git a/weclone/train/train_sft.py b/weclone/train/train_sft.py new file mode 100644 index 0000000000000000000000000000000000000000..07ab3d710c222d3bda2e604e38e63648d3cccfdf --- /dev/null +++ b/weclone/train/train_sft.py @@ -0,0 +1,32 @@ +import os +import sys +import json +from llamafactory.train.tuner import run_exp +from llamafactory.extras.misc import get_current_device +from weclone.utils.config import load_config +from weclone.utils.log import logger +from weclone.data.clean.strategies import LLMCleaningStrategy + +def main(): + train_config = load_config(arg_type="train_sft") + dataset_config = load_config(arg_type="make_dataset") + + device = get_current_device() + if device == "cpu": + logger.warning("请注意你正在使用CPU训练,非Mac设备可能会出现问题") + + cleaner = LLMCleaningStrategy(make_dataset_config=dataset_config) + cleaned_data_path = cleaner.clean() + + if not os.path.exists(cleaned_data_path): + logger.error(f"错误:文件 '{cleaned_data_path}' 不存在,请确保数据处理步骤已正确生成该文件。") + sys.exit(1) + + formatted_config = json.dumps(train_config, indent=4, ensure_ascii=False) + logger.info(f"微调配置:\n{formatted_config}") + + run_exp(train_config) + + +if __name__ == "__main__": + main() diff --git a/weclone/utils/__init__.py b/weclone/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/weclone/utils/config.py b/weclone/utils/config.py new file mode 100644 index 
0000000000000000000000000000000000000000..7092c511024b14fae9ed5115102be909abfae10b --- /dev/null +++ b/weclone/utils/config.py @@ -0,0 +1,53 @@ +import os +import commentjson +import sys + +from .log import logger +from .tools import dict_to_argv + + +def load_config(arg_type: str): + config_path = os.environ.get("WECLONE_CONFIG_PATH", "./settings.jsonc") + logger.info(f"Loading configuration from: {config_path}") # Add logging to see which file is loaded + try: + with open(config_path, "r", encoding="utf-8") as f: + s_config: dict = commentjson.load(f) + except FileNotFoundError: + logger.error(f"Configuration file not found: {config_path}") + sys.exit(1) # Exit if config file is not found + except Exception as e: + logger.error(f"Error loading configuration file {config_path}: {e}") + sys.exit(1) + + if arg_type == "cli_args": + config = s_config["cli_args"] + elif arg_type == "web_demo" or arg_type == "api_service": + # infer_args和common_args求并集 + config = {**s_config["infer_args"], **s_config["common_args"]} + elif arg_type == "train_pt": + config = {**s_config["train_pt_args"], **s_config["common_args"]} + elif arg_type == "train_sft": + config = {**s_config["train_sft_args"], **s_config["common_args"]} + if s_config["make_dataset_args"]["prompt_with_history"]: + dataset_info_path = os.path.join(config["dataset_dir"], "dataset_info.json") + dataset_info = commentjson.load(open(dataset_info_path, "r", encoding="utf-8"))[config["dataset"]] + if dataset_info["columns"].get("history") is None: + logger.warning(f"{config['dataset']}数据集不包history字段,尝试使用wechat-sft-with-history数据集") + config["dataset"] = "wechat-sft-with-history" + + elif arg_type == "make_dataset": + config = {**s_config["make_dataset_args"], **s_config["common_args"]} + config["dataset"] = s_config["train_sft_args"]["dataset"] + config["dataset_dir"] = s_config["train_sft_args"]["dataset_dir"] + config["cutoff_len"] = s_config["train_sft_args"]["cutoff_len"] + else: + raise ValueError("暂不支持的参数类型") + + if "train" in arg_type: + config["output_dir"] = config["adapter_name_or_path"] + config.pop("adapter_name_or_path") + config["do_train"] = True + + sys.argv += dict_to_argv(config) + + return config diff --git a/weclone/utils/length_cdf.py b/weclone/utils/length_cdf.py new file mode 100644 index 0000000000000000000000000000000000000000..3d67578f25d1a03303a5c263892f7cdbcc814835 --- /dev/null +++ b/weclone/utils/length_cdf.py @@ -0,0 +1,73 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import defaultdict + +import fire +from tqdm import tqdm +from weclone.utils.log import logger + +from llamafactory.data import get_dataset, get_template_and_fix_tokenizer +from llamafactory.hparams import get_train_args +from llamafactory.model import load_tokenizer + + +def length_cdf( + model_name_or_path: str = "./Qwen2.5-7B-Instruct", + dataset: str = "wechat-sft", + dataset_dir: str = "./dataset/res_csv/sft", + template: str = "qwen", + interval: int = 256, +): + r"""Calculate the distribution of the input lengths in the dataset. + + Usage: export CUDA_VISIBLE_DEVICES=0 + python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default + """ + logger.info("开始计算cutoff_len......") + + model_args, data_args, training_args, _, _ = get_train_args( + { + "stage": "sft", + "model_name_or_path": model_name_or_path, + "dataset": dataset, + "dataset_dir": dataset_dir, + "template": template, + "cutoff_len": 1_000_000, + "preprocessing_num_workers": 16, + "output_dir": "dummy_dir", + "overwrite_cache": True, + "do_train": True, + } + ) + tokenizer_module = load_tokenizer(model_args) + template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args) # type: ignore + trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"] # type: ignore + total_num = len(trainset) # type: ignore + length_dict = defaultdict(int) + for sample in tqdm(trainset["input_ids"], desc="Collecting lengths"): # type: ignore + length_dict[len(sample) // interval * interval] += 1 + + length_tuples = list(length_dict.items()) + length_tuples.sort() + count_accu, prob_accu = 0, 0 + logger.info(" cutoff_len设置建议:") + for length, count in length_tuples: + count_accu += count + prob_accu += count / total_num * 100 + logger.success(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.") + + +if __name__ == "__main__": + fire.Fire(length_cdf) diff --git a/weclone/utils/log.py b/weclone/utils/log.py new file mode 100644 index 0000000000000000000000000000000000000000..f01a158064b5ee0d4e6cf395ffd6747a2c85bdc9 --- /dev/null +++ b/weclone/utils/log.py @@ -0,0 +1,91 @@ +from loguru import logger +import sys +from functools import wraps + +logger.remove() + +logger.add( + sys.stderr, + format="[WeClone] {level.name[0]} | {time:HH:mm:ss} | {message}", + colorize=True, + level="INFO", +) + +logger.add( + "logs/weclone.log", # 日志文件路径 + rotation="1 day", # 每天轮换一个新的日志文件 + retention="7 days", # 保留最近7天的日志文件 + compression="zip", # 压缩旧的日志文件 + level="DEBUG", # 文件日志级别 + format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}", # 日志格式 + encoding="utf-8", # 文件编码 + enqueue=True, # 异步写入,避免阻塞 +) + + +def capture_output(func): + @wraps(func) + def wrapper(*args, **kwargs): + log_sink_buffer = [] + + def list_sink(message): + log_sink_buffer.append(message.record["message"]) + + sink_id = logger.add(list_sink, format="{message}", level="INFO") + + original_stdout = sys.stdout + original_stderr = sys.stderr + + class OutputTeeToGlobalLog: + def __init__(self, original_stream, log_method): + self.original_stream = original_stream + self.log_method = log_method + self.current_line_content = "" # Represents the current state of the line to be logged + + def write(self, data_chunk): + self.original_stream.write(data_chunk) # Pass through to console + + if data_chunk.endswith("\\r") and "\\n" not in data_chunk: + self.current_line_content = data_chunk[:-1] # Store without 
the trailing \\r + return + + full_buffer = self.current_line_content + data_chunk + lines_to_process = full_buffer.split("\\n") + + for i in range(len(lines_to_process) - 1): + line = lines_to_process[i] + final_content_of_line = line + last_cr = line.rfind("\\r") + if last_cr != -1: + final_content_of_line = line[last_cr + 1 :] + + escaped_log = final_content_of_line.replace("{", "{{").replace("}", "}}") + if final_content_of_line.strip() or line: + self.log_method(escaped_log, raw=True) + + self.current_line_content = lines_to_process[-1] + + def flush(self): + self.original_stream.flush() + if self.current_line_content: + final_content_of_line = self.current_line_content + last_cr = self.current_line_content.rfind("\\r") + if last_cr != -1: + final_content_of_line = self.current_line_content[last_cr + 1 :] + + escaped_log = final_content_of_line.replace("{", "{{").replace("}", "}}") + if final_content_of_line.strip() or self.current_line_content: + self.log_method(escaped_log, raw=True) + self.current_line_content = "" + + sys.stdout = OutputTeeToGlobalLog(original_stdout, logger.opt(raw=True).info) + sys.stderr = OutputTeeToGlobalLog(original_stderr, logger.opt(raw=True).error) + + try: + func(*args, **kwargs) + finally: + sys.stdout = original_stdout + sys.stderr = original_stderr + logger.remove(sink_id) + + return wrapper diff --git a/weclone/utils/tools.py b/weclone/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..d10ecf55e20d31dde48346632e0f8ede76205da4 --- /dev/null +++ b/weclone/utils/tools.py @@ -0,0 +1,9 @@ +def dict_to_argv(d): + argv = [] + for k, v in d.items(): + argv.append("--" + k) + if v is not None: + argv.append(str(v)) + return argv + +
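For reference, a short sketch of how `dict_to_argv` in `weclone/utils/tools.py` flattens a config dict into CLI-style arguments; `load_config` appends this list to `sys.argv` so that LLaMA-Factory's argument parser can pick the options up. The config fragment below is hypothetical:

```python
from weclone.utils.tools import dict_to_argv

# Hypothetical config fragment, for illustration only.
config = {"stage": "sft", "do_train": True, "cutoff_len": 256, "adapter_name_or_path": None}
print(dict_to_argv(config))
# ['--stage', 'sft', '--do_train', 'True', '--cutoff_len', '256', '--adapter_name_or_path']
```

Note that a `None` value yields a bare flag with no argument, while every other value is stringified and appended after its `--key`.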