Spaces:
Sleeping
Sleeping
| In [4]: | |
| ! pip install -U git+https://github.com/huggingface/transformers.git | |
| ! pip install -U git+https://github.com/huggingface/accelerate.git | |
| Collecting git+https://github.com/huggingface/transformers.git | |
| Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-srwrto6l | |
| Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-srwrto6l | |
| Traceback (most recent call last): | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/base_command.py", line 179, in exc_logging_wrapper | |
| status = run_func(*args) | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/req_command.py", line 67, in wrapper | |
| return func(self, options, args) | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/commands/install.py", line 377, in run | |
| requirement_set = resolver.resolve( | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/resolution/resolvelib/resolver.py", line 76, in resolve | |
| collected = self.factory.collect_root_requirements(root_reqs) | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/resolution/resolvelib/factory.py", line 538, in collect_root_requirements | |
| reqs = list( | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/resolution/resolvelib/factory.py", line 494, in _make_requirements_from_install_req | |
| cand = self._make_base_candidate_from_link( | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/resolution/resolvelib/factory.py", line 231, in _make_base_candidate_from_link | |
| self._link_candidate_cache[link] = LinkCandidate( | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/resolution/resolvelib/candidates.py", line 303, in __init__ | |
| super().__init__( | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/resolution/resolvelib/candidates.py", line 158, in __init__ | |
| self.dist = self._prepare() | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/resolution/resolvelib/candidates.py", line 235, in _prepare | |
| dist = self._prepare_distribution() | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/resolution/resolvelib/candidates.py", line 314, in _prepare_distribution | |
| return preparer.prepare_linked_requirement(self._ireq, parallel_builds=True) | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/operations/prepare.py", line 527, in prepare_linked_requirement | |
| return self._prepare_linked_requirement(req, parallel_builds) | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/operations/prepare.py", line 598, in _prepare_linked_requirement | |
| local_file = unpack_url( | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/operations/prepare.py", line 159, in unpack_url | |
| unpack_vcs_link(link, location, verbosity=verbosity) | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/operations/prepare.py", line 81, in unpack_vcs_link | |
| vcs_backend.unpack(location, url=hide_url(link.url), verbosity=verbosity) | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/vcs/versioncontrol.py", line 589, in unpack | |
| self.obtain(location, url=url, verbosity=verbosity) | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/vcs/versioncontrol.py", line 502, in obtain | |
| self.fetch_new(dest, url, rev_options, verbosity=verbosity) | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/vcs/git.py", line 277, in fetch_new | |
| self.run_command( | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/vcs/versioncontrol.py", line 631, in run_command | |
| return call_subprocess( | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/utils/subprocess.py", line 151, in call_subprocess | |
| line: str = proc.stdout.readline() | |
| KeyboardInterrupt | |
| During handling of the above exception, another exception occurred: | |
| Traceback (most recent call last): | |
| File "/usr/lib/python3.10/logging/__init__.py", line 1732, in isEnabledFor | |
| return self._cache[level] | |
| KeyError: 50 | |
| During handling of the above exception, another exception occurred: | |
| Traceback (most recent call last): | |
| File "/usr/local/bin/pip3", line 8, in <module> | |
| sys.exit(main()) | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/main.py", line 80, in main | |
| return command.main(cmd_args) | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/base_command.py", line 100, in main | |
| return self._main(args) | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/base_command.py", line 232, in _main | |
| return run(options, args) | |
| File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/base_command.py", line 215, in exc_logging_wrapper | |
| logger.critical("Operation cancelled by user") | |
| File "/usr/lib/python3.10/logging/__init__.py", line 1523, in critical | |
| if self.isEnabledFor(CRITICAL): | |
| File "/usr/lib/python3.10/logging/__init__.py", line 1740, in isEnabledFor | |
| level >= self.getEffectiveLevel() | |
| File "/usr/lib/python3.10/logging/__init__.py", line 1710, in getEffectiveLevel | |
| def getEffectiveLevel(self): | |
| KeyboardInterrupt | |
| ^C | |
| Collecting git+https://github.com/huggingface/accelerate.git | |
| Cloning https://github.com/huggingface/accelerate.git to /tmp/pip-req-build-o07w0b0ye | |
| Running command git clone --filter=blob:none --quiet https://github.com/huggingface/accelerate.git /tmp/pip-req-build-07w0b0ye | |
| Resolved https://github.com/huggingface/accelerate.git to commit 0e61127b5a6f99df51dd66803cd163f07bb85104 | |
| Installing build dependencies ... done | |
| Getting requirements to build wheel ... done | |
| Preparing metadata (pyproject.toml) ... done | |
| Requirement already satisfied: numpy<3.0.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate==1.1.0.dev0) (1.26.4) | |
| Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate==1.1.0.dev0) (24.1) | |
| Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate==1.1.0.dev0) (5.9.5) | |
| Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate==1.1.0.dev0) (6.0.2) | |
| Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate==1.1.0.dev0) (2.4.1+cu121) | |
| Requirement already satisfied: huggingface-hub>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from accelerate==1.1.0.dev0) (0.24.7) | |
| Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.10/dist-packages (from accelerate==1.1.0.dev0) (0.4.5) | |
| Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate==1.1.0.dev0) (3.16.1) | |
| Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate==1.1.0.dev0) (2024.6.1) | |
| Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate==1.1.0.dev0) (2.32.3) | |
| Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate==1.1.0.dev0) (4.66.5) | |
| Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate==1.1.0.dev0) (4.12.2) | |
| Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate==1.1.0.dev0) (1.13.3) | |
| Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate==1.1.0.dev0) (3.3) | |
| Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate==1.1.0.dev0) (3.1.4) | |
| Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate==1.1.0.dev0) (2.1.5) | |
| Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate==1.1.0.dev0) (3.3.2) | |
| Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate==1.1.0.dev0) (3.10) | |
| Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate==1.1.0.dev0) (2.2.3) | |
| Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate==1.1.0.dev0) (2024.8.30) | |
| Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate==1.1.0.dev0) (1.3.0) | |
| Building wheels for collected packages: accelerate | |
| Building wheel for accelerate (pyproject.toml) ... done | |
| Created wheel for accelerate: filename=accelerate-1.1.0.dev0-py3-none-any.whl size=332417 sha256=5fb12ca3a9ceac357ee327959835b13f847381455bd8cfd6e8d80d23b3fc25be | |
| Stored in directory: /tmp/pip-ephem-wheel-cache-le3a9ije/wheels/9c/a3/1e/47368f9b6575655fe9ee1b6350cfa7d4b0befe66a35f8a8365 | |
| Successfully built accelerate | |
| Installing collected packages: accelerate | |
| Attempting uninstall: accelerate | |
| Found existing installation: accelerate 0.34.2 | |
| Uninstalling accelerate-0.34.2: | |
| Successfully uninstalled accelerate-0.34.2 | |
| Successfully installed accelerate-1.1.0.dev0 | |
| In [5]: | |
| ! pip install datasets | |
| Collecting datasets | |
| Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB) | |
| Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1) | |
| Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4) | |
| Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (16.1.0) | |
| Collecting dill<0.3.9,>=0.3.0 (from datasets) | |
| Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB) | |
| Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2) | |
| Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3) | |
| Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.5) | |
| Collecting xxhash (from datasets) | |
| Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB) | |
| Collecting multiprocess (from datasets) | |
| Downloading multiprocess-0.70.17-py310-none-any.whl.metadata (7.2 kB) | |
| Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1) | |
| Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.9) | |
| Requirement already satisfied: huggingface-hub>=0.22.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.24.7) | |
| Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.1) | |
| Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2) | |
| Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.3) | |
| Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1) | |
| Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0) | |
| Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1) | |
| Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0) | |
| Requirement already satisfied: yarl<2.0,>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.13.1) | |
| Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3) | |
| Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.22.0->datasets) (4.12.2) | |
| Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.3.2) | |
| Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10) | |
| Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.2.3) | |
| Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30) | |
| INFO: pip is looking at multiple versions of multiprocess to determine which version is compatible with other requirements. This could take a while. | |
| Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB) | |
| Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2) | |
| Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2) | |
| Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2) | |
| Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0) | |
| Downloading datasets-3.0.1-py3-none-any.whl (471 kB) | |
| ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 471.6/471.6 kB 24.2 MB/s eta 0:00:00 | |
| Downloading dill-0.3.8-py3-none-any.whl (116 kB) | |
| ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 116.3/116.3 kB 9.8 MB/s eta 0:00:00 | |
| Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB) | |
| ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.8/134.8 kB 11.3 MB/s eta 0:00:00 | |
| Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB) | |
| ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 194.1/194.1 kB 15.2 MB/s eta 0:00:00 | |
| Installing collected packages: xxhash, dill, multiprocess, datasets | |
| Successfully installed datasets-3.0.1 dill-0.3.8 multiprocess-0.70.16 xxhash-3.5.0 | |
| In [6]: | |
# Hub checkpoint to fine-tune, plus the per-device batch size that the
# TrainingArguments cell uses for both training and evaluation.
model_checkpoint, batch_size = "microsoft/resnet-50", 128
| In [7]: | |
| from datasets import load_dataset | |
| In [8]: | |
| !pip install evaluate | |
| Collecting evaluate | |
| Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB) | |
| Requirement already satisfied: datasets>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from evaluate) (3.0.1) | |
| Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from evaluate) (1.26.4) | |
| Requirement already satisfied: dill in /usr/local/lib/python3.10/dist-packages (from evaluate) (0.3.8) | |
| Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from evaluate) (2.2.2) | |
| Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from evaluate) (2.32.3) | |
| Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from evaluate) (4.66.5) | |
| Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from evaluate) (3.5.0) | |
| Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from evaluate) (0.70.16) | |
| Requirement already satisfied: fsspec>=2021.05.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (2024.6.1) | |
| Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from evaluate) (0.24.7) | |
| Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from evaluate) (24.1) | |
| Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->evaluate) (3.16.1) | |
| Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->evaluate) (16.1.0) | |
| Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->evaluate) (3.10.9) | |
| Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->evaluate) (6.0.2) | |
| Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.12.2) | |
| Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->evaluate) (3.3.2) | |
| Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->evaluate) (3.10) | |
| Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->evaluate) (2.2.3) | |
| Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->evaluate) (2024.8.30) | |
| Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->evaluate) (2.8.2) | |
| Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->evaluate) (2024.2) | |
| Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->evaluate) (2024.2) | |
| Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (2.4.3) | |
| Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.1) | |
| Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (24.2.0) | |
| Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.4.1) | |
| Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (6.1.0) | |
| Requirement already satisfied: yarl<2.0,>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.13.1) | |
| Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (4.0.3) | |
| Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->evaluate) (1.16.0) | |
| Downloading evaluate-0.4.3-py3-none-any.whl (84 kB) | |
| ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 84.0/84.0 kB 7.1 MB/s eta 0:00:00 | |
| Installing collected packages: evaluate | |
| Successfully installed evaluate-0.4.3 | |
| In [9]: | |
from google.colab import drive

# Mount Google Drive so the dataset folder under MyDrive can be read below.
drive.mount(mountpoint="/content/drive/")
| Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True). | |
| In [10]: | |
from evaluate import load

# Accuracy metric; compute_metrics calls metric.compute(...) on each evaluation.
metric = load(path="accuracy")
| /usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: | |
| The secret `HF_TOKEN` does not exist in your Colab secrets. | |
| To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session. | |
| You will be able to reuse this secret in all of your notebooks. | |
| Please note that authentication is recommended but still optional to access public models or datasets. | |
| warnings.warn( | |
| Downloading builder script: 0%| | 0.00/4.20k [00:00<?, ?B/s] | |
| In [11]: | |
# Load the image dataset from the mounted Drive folder. The absolute path under
# the "/content/drive/" mount point is used so loading does not depend on the
# notebook's current working directory (the previous relative path only
# resolved when cwd was /content).
dataset = load_dataset(
    "imagefolder", data_dir="/content/drive/MyDrive/Face Mask Dataset"
)

# Class names come from the dataset's "label" feature; build the two lookup
# tables (name -> id, id -> name) that are passed to the model when it is loaded.
labels = dataset["train"].features["label"].names
label2id = {name: idx for idx, name in enumerate(labels)}
id2label = {idx: name for idx, name in enumerate(labels)}
| Resolving data files: 0%| | 0/10000 [00:00<?, ?it/s] | |
| Downloading data: 0%| | 0/10000 [00:00<?, ?files/s] | |
| Generating train split: 0 examples [00:00, ? examples/s] | |
| New section (新段落) | |
| In [12]: | |
from transformers import AutoImageProcessor

# Fetch the preprocessing config that ships with the checkpoint; the
# transforms cell reuses its size and normalization statistics.
image_processor = AutoImageProcessor.from_pretrained(
    model_checkpoint,
)
image_processor  # last expression: displayed as the cell output
| preprocessor_config.json: 0%| | 0.00/266 [00:00<?, ?B/s] | |
| Out[12]: | |
| ConvNextImageProcessor { | |
| "crop_pct": 0.875, | |
| "do_normalize": true, | |
| "do_rescale": true, | |
| "do_resize": true, | |
| "image_mean": [ | |
| 0.485, | |
| 0.456, | |
| 0.406 | |
| ], | |
| "image_processor_type": "ConvNextImageProcessor", | |
| "image_std": [ | |
| 0.229, | |
| 0.224, | |
| 0.225 | |
| ], | |
| "resample": 3, | |
| "rescale_factor": 0.00392156862745098, | |
| "size": { | |
| "shortest_edge": 224 | |
| } | |
| } | |
| In [13]: | |
from torchvision.transforms import (
    CenterCrop,
    ColorJitter,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    RandomRotation,
    Resize,
    ToTensor,
)

# Normalize with the per-channel mean/std from the checkpoint's image processor.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# The image processor describes its target size either as explicit
# height/width or as a shortest edge (optionally with a longest-edge cap).
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = image_processor.size.get("longest_edge")
else:
    # Without this branch `size`/`crop_size` stay undefined and the Compose
    # calls below fail with an unhelpful NameError.
    raise ValueError(
        f"Unsupported image_processor.size format: {image_processor.size!r}"
    )

# Training pipeline: random crop/flip/rotation/colour jitter as augmentation.
train_transforms = Compose(
    [
        RandomResizedCrop(crop_size),
        RandomHorizontalFlip(),
        RandomRotation(degrees=15),
        ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        ToTensor(),
        normalize,
    ]
)

# Validation pipeline: deterministic only. Random rotation / colour jitter
# here would make eval metrics vary between runs and would score the model
# on distorted images instead of the images it will actually see.
# `max_size` (None unless the processor defines "longest_edge") caps the
# longer side after the shortest-edge resize.
val_transforms = Compose(
    [
        Resize(size, max_size=max_size),
        CenterCrop(crop_size),
        ToTensor(),
        normalize,
    ]
)


def preprocess_train(example_batch):
    """Add augmented ``pixel_values`` tensors to a batch of examples.

    Args:
        example_batch: Batch dict whose ``"image"`` entry is a list of PIL images.

    Returns:
        The same dict, with ``"pixel_values"`` set to one transformed tensor per image.
    """
    example_batch["pixel_values"] = [
        train_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch


def preprocess_val(example_batch):
    """Add deterministically preprocessed ``pixel_values`` to a batch of examples.

    Args:
        example_batch: Batch dict whose ``"image"`` entry is a list of PIL images.

    Returns:
        The same dict, with ``"pixel_values"`` set to one transformed tensor per image.
    """
    example_batch["pixel_values"] = [
        val_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch
| In [15]: | |
# Hold out 30% of the labelled images for validation.
# - A fixed seed makes the split, and therefore the reported eval metrics,
#   reproducible across runs.
# - Stratifying on the "label" ClassLabel column keeps the class balance
#   identical in both parts.
splits = dataset["train"].train_test_split(
    test_size=0.3, seed=42, stratify_by_column="label"
)
train_ds = splits["train"]
val_ds = splits["test"]

# Attach the per-split preprocessing (augmented for train, deterministic for val).
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)
| In [16]: | |
from transformers import AutoModelForImageClassification, Trainer, TrainingArguments

# Load the pretrained backbone with a classification head sized for our labels.
# ignore_mismatched_sizes lets the checkpoint's 1000-way head be replaced by a
# newly initialized one instead of raising a shape-mismatch error.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    ignore_mismatched_sizes=True,
    id2label=id2label,
    label2id=label2id,
)
| config.json: 0%| | 0.00/69.6k [00:00<?, ?B/s] | |
| model.safetensors: 0%| | 0.00/102M [00:00<?, ?B/s] | |
| Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match: | |
| - classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated | |
| - classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([2, 2048]) in the model instantiated | |
| You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. | |
| In [23]: | |
# Output directory is named after the checkpoint, e.g. "resnet-50-finetuned".
model_name = model_checkpoint.split("/")[-1]

args = TrainingArguments(
    f"{model_name}-finetuned",
    # Keep the raw "image" column: the set_transform preprocessors read it.
    remove_unused_columns=False,
    # `evaluation_strategy` is deprecated (removed in transformers 4.46).
    eval_strategy="epoch",
    # Checkpoint on the same schedule as evaluation, as load_best_model_at_end requires.
    save_strategy="epoch",
    save_total_limit=5,
    learning_rate=1e-3,
    per_device_train_batch_size=batch_size,
    # Effective train batch per device = batch_size * 2.
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    warmup_ratio=0.1,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)
| /usr/local/lib/python3.10/dist-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead | |
| warnings.warn( | |
| In [19]: | |
import numpy as np


def compute_metrics(eval_pred):
    """Compute accuracy over the whole evaluation set.

    Trainer gathers predictions for the full eval dataset before calling this,
    so the result is not per-batch.

    Args:
        eval_pred: Object with ``predictions`` (class logits, or a tuple whose
            first element is the logits) and ``label_ids``.

    Returns:
        dict: Output of ``metric.compute``, e.g. ``{"accuracy": ...}``.
    """
    logits = eval_pred.predictions
    # NOTE: some models return extra outputs alongside the logits; in that
    # case predictions arrive as a tuple with the logits first.
    if isinstance(logits, tuple):
        logits = logits[0]
    # Class scores are on the last axis; argmax yields the predicted label id.
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)
| In [20]: | |
import torch


def collate_fn(examples):
    """Merge preprocessed examples into a single model-input batch.

    Args:
        examples: Sequence of dicts, each holding a ``"pixel_values"`` tensor
            (all of identical shape) and an integer ``"label"``.

    Returns:
        dict: ``"pixel_values"`` -> images stacked along a new leading dim;
        ``"labels"`` -> 1-D tensor of the labels in input order.
    """
    images, targets = [], []
    for item in examples:
        images.append(item["pixel_values"])
        targets.append(item["label"])
    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }
| In [24]: | |
# Wire the model, data splits, preprocessing config, metric and batching together.
# NOTE(review): recent transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; confirm the installed version supports it before switching.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)
| In [25]: | |
# Fine-tune the model.
train_results = trainer.train()

# Save the model.
trainer.save_model()

# Report and persist the training metrics, then the trainer state.
train_metrics = train_results.metrics
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)
trainer.save_state()
| [54/54 2:06:56, Epoch 1/2] | |
| Epoch Training Loss Validation Loss Accuracy | |
| 0 0.149200 0.036754 0.986667 | |
| 1 0.038400 0.020120 0.993333 | |
| ***** train metrics ***** | |
| epoch = 1.9636 | |
| total_flos = 272606314GF | |
| train_loss = 0.1626 | |
| train_runtime = 2:10:43.64 | |
| train_samples_per_second = 1.785 | |
| train_steps_per_second = 0.007 | |
| In [26]: | |
# Final evaluation on the Trainer's eval_dataset (the validation split).
# With load_best_model_at_end=True, the best checkpoint is the one reloaded.
metrics = trainer.evaluate()

trainer.log_metrics("eval", metrics)   # print a formatted summary
trainer.save_metrics("eval", metrics)  # persist to disk
| [24/24 00:40] | |
| ***** eval metrics ***** | |
| epoch = 1.9636 | |
| eval_accuracy = 0.9947 | |
| eval_loss = 0.0184 | |
| eval_runtime = 0:00:44.17 | |
| eval_samples_per_second = 67.914 | |
| eval_steps_per_second = 0.543 |