Yes, this sounds like the model simply learns not to attend to earlier tokens once an EOS token separates them. With a packed-sequence mask you can enforce this explicitly by masking out the previous tokens.
Hey @shantanuagarwal, glad you enjoyed the article! Even though I haven't tried it out myself, you should be able to leverage PyTorch's FlexAttention API for this. Have a look at the tutorial here: https://pytorch.org/blog/flexattention/. The section "Document Masking/Jagged Sequences" covers these packed-sequence masks.
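For reference, here is a minimal (untested) sketch of what such a packed-sequence/document mask could look like with FlexAttention. It assumes a recent PyTorch with `torch.nn.attention.flex_attention` and a CUDA device; the sequence length, document lengths, and the `document_id` layout are made up purely for illustration:

```python
# Sketch: causal attention restricted to each packed document via FlexAttention.
# Assumes PyTorch >= 2.5 with torch.nn.attention.flex_attention and a CUDA device.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, SEQ_LEN, HEAD_DIM = 1, 4, 128, 64

# Hypothetical packing: three documents of lengths 40, 56, 32 in one sequence.
document_id = torch.cat([
    torch.full((40,), 0),
    torch.full((56,), 1),
    torch.full((32,), 2),
]).to("cuda")

def document_causal_mask(b, h, q_idx, kv_idx):
    # Attend only within the same document, and only to earlier positions.
    same_doc = document_id[q_idx] == document_id[kv_idx]
    causal = q_idx >= kv_idx
    return same_doc & causal

block_mask = create_block_mask(
    document_causal_mask, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN
)

q = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flex_attention(q, k, v, block_mask=block_mask)
```

In practice you would build `document_id` from wherever your packing logic places the EOS boundaries, and wrap `flex_attention` in `torch.compile` for performance, as the tutorial describes.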