fix pytorch 2.1.0 build, add multipack docs (#722)
Browse files- .github/workflows/main.yml +1 -0
- docker/Dockerfile +4 -0
- docs/multipack.md +51 -0
.github/workflows/main.yml
CHANGED
|
@@ -51,6 +51,7 @@ jobs:
|
|
| 51 |
build-args: |
|
| 52 |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
| 53 |
CUDA=${{ matrix.cuda }}
|
|
|
|
| 54 |
file: ./docker/Dockerfile
|
| 55 |
push: ${{ github.event_name != 'pull_request' }}
|
| 56 |
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
|
|
|
| 51 |
build-args: |
|
| 52 |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
| 53 |
CUDA=${{ matrix.cuda }}
|
| 54 |
+
PYTORCH_VERSION=${{ matrix.pytorch }}
|
| 55 |
file: ./docker/Dockerfile
|
| 56 |
push: ${{ github.event_name != 'pull_request' }}
|
| 57 |
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
docker/Dockerfile
CHANGED
|
@@ -5,6 +5,9 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
|
| 5 |
ARG AXOLOTL_EXTRAS=""
|
| 6 |
ARG CUDA="118"
|
| 7 |
ENV BNB_CUDA_VERSION=$CUDA
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
RUN apt-get update && \
|
| 10 |
apt-get install -y vim curl
|
|
@@ -16,6 +19,7 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
|
| 16 |
WORKDIR /workspace/axolotl
|
| 17 |
|
| 18 |
# If AXOLOTL_EXTRAS is set, append it in brackets
|
|
|
|
| 19 |
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
| 20 |
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
| 21 |
else \
|
|
|
|
| 5 |
ARG AXOLOTL_EXTRAS=""
|
| 6 |
ARG CUDA="118"
|
| 7 |
ENV BNB_CUDA_VERSION=$CUDA
|
| 8 |
+
ARG PYTORCH_VERSION="2.0.1"
|
| 9 |
+
|
| 10 |
+
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
| 11 |
|
| 12 |
RUN apt-get update && \
|
| 13 |
apt-get install -y vim curl
|
|
|
|
| 19 |
WORKDIR /workspace/axolotl
|
| 20 |
|
| 21 |
# If AXOLOTL_EXTRAS is set, append it in brackets
|
| 22 |
+
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
| 23 |
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
| 24 |
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
| 25 |
else \
|
docs/multipack.md
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Multipack
|
| 2 |
+
|
| 3 |
+
4k context, bsz =4,
|
| 4 |
+
each character represents 256 tokens
|
| 5 |
+
X represents a padding token
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
| 9 |
+
[[ A A A A A A A A A A A ]
|
| 10 |
+
B B B B B B ]
|
| 11 |
+
C C C C C C C ]
|
| 12 |
+
D D D D ]]
|
| 13 |
+
|
| 14 |
+
[[ E E E E E E E E ]
|
| 15 |
+
[ F F F F ]
|
| 16 |
+
[ G G G ]
|
| 17 |
+
[ H H H H ]]
|
| 18 |
+
|
| 19 |
+
[[ I I I ]
|
| 20 |
+
[ J J J ]
|
| 21 |
+
[ K K K K K]
|
| 22 |
+
[ L L L ]]
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
after padding to longest input in each step
|
| 26 |
+
```
|
| 27 |
+
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
| 28 |
+
[[ A A A A A A A A A A A ]
|
| 29 |
+
B B B B B B X X X X X X ]
|
| 30 |
+
C C C C C C C X X X X ]
|
| 31 |
+
D D D D X X X X X X X ]]
|
| 32 |
+
|
| 33 |
+
[[ E E E E E E E E ]
|
| 34 |
+
[ F F F F X X X X ]
|
| 35 |
+
[ G G G X X X X X ]
|
| 36 |
+
[ H H H H X X X X ]]
|
| 37 |
+
|
| 38 |
+
[[ I I I X X ]
|
| 39 |
+
[ J J J X X ]
|
| 40 |
+
[ K K K K K ]
|
| 41 |
+
[ L L L X X ]]
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
w packing ( note it's the same effective number of tokens per step, but a true bsz of 1)
|
| 45 |
+
```
|
| 46 |
+
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
| 47 |
+
[[ A A A A A A A A A A A B B B B B
|
| 48 |
+
B C C C C C C C D D D D E E E E
|
| 49 |
+
E E E E F F F F F G G G H H H H
|
| 50 |
+
I I I J J J J K K K K K L L L X ]]
|
| 51 |
+
```
|