Commit ccd1a19e authored by 孙傲

Download repositories in advance

parent 1c05717c
......@@ -4,7 +4,6 @@ __pycache__
*.pth
/ESRGAN/*
/SwinIR/*
/repositories
/venv
/tmp
/model.ckpt
......
# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
#ECCN:Open Source
# Salesforce Open Source Community Code of Conduct
## About the Code of Conduct
Equality is a core value at Salesforce. We believe a diverse and inclusive
community fosters innovation and creativity, and are committed to building a
culture where everyone feels included.
Salesforce open-source projects are committed to providing a friendly, safe, and
welcoming environment for all, regardless of gender identity and expression,
sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
race, age, religion, level of experience, education, socioeconomic status, or
other similar personal characteristics.
The goal of this code of conduct is to specify a baseline standard of behavior so
that people with different social values and communication styles can work
together effectively, productively, and respectfully in our open source community.
It also establishes a mechanism for reporting issues and resolving conflicts.
All questions and reports of abusive, harassing, or otherwise unacceptable behavior
in a Salesforce open-source project may be reported by contacting the Salesforce
Open Source Conduct Committee at ossconduct@salesforce.com.
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of gender
identity and expression, sexual orientation, disability, physical appearance,
body size, ethnicity, nationality, race, age, religion, level of experience, education,
socioeconomic status, or other similar personal characteristics.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy toward other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Personal attacks, insulting/derogatory comments, or trolling
* Public or private harassment
* Publishing, or threatening to publish, others' private information—such as
a physical or electronic address—without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
* Advocating for or encouraging any of the above behaviors
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned with this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project email
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the Salesforce Open Source Conduct Committee
at ossconduct@salesforce.com. All complaints will be reviewed and investigated
and will result in a response that is deemed necessary and appropriate to the
circumstances. The committee is obligated to maintain confidentiality with
regard to the reporter of an incident. Further details of specific enforcement
policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership and the Salesforce Open Source Conduct
Committee.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
It includes adaptions and additions from [Go Community Code of Conduct][golang-coc],
[CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
[contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
[golang-coc]: https://golang.org/conduct
[cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
[microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
[cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
Copyright (c) 2022, Salesforce.com, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
## BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
<img src="BLIP.gif" width="700">
This is the PyTorch code of the <a href="https://arxiv.org/abs/2201.12086">BLIP paper</a> [[blog](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)]. The code has been tested on PyTorch 1.10.
To install the dependencies, run <pre>pip install -r requirements.txt</pre>
Catalog:
- [x] Inference demo
- [x] Pre-trained and finetuned checkpoints
- [x] Finetuning code for Image-Text Retrieval, Image Captioning, VQA, and NLVR2
- [x] Pre-training code
- [x] Zero-shot video-text retrieval
- [x] Download of bootstrapped pre-training datasets
### Inference demo:
Run our interactive demo using [Colab notebook](https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb) (no GPU needed).
The demo includes code for:
1. Image captioning
2. Open-ended visual question answering
3. Multimodal / unimodal feature extraction
4. Image-text matching
Try out the [Web demo](https://huggingface.co/spaces/Salesforce/BLIP), integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio).
A Replicate web demo and Docker image are also available at [![Replicate](https://replicate.com/salesforce/blip/badge)](https://replicate.com/salesforce/blip)
### Pre-trained checkpoints:
Num. pre-train images | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
--- | :---: | :---: | :---:
14M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth">Download</a>| - | -
129M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth">Download</a> | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth">Download</a>
### Finetuned checkpoints:
Task | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
--- | :---: | :---: | :---:
Image-Text Retrieval (COCO) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth">Download</a>
Image-Text Retrieval (Flickr30k) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_flickr.pth">Download</a>
Image Captioning (COCO) | - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth">Download</a> |
VQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth">Download</a> | -
NLVR2 | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth">Download</a>| - | -
### Image-Text Retrieval:
1. Download COCO and Flickr30k datasets from the original websites, and set 'image_root' in configs/retrieval_{dataset}.yaml accordingly.
2. To evaluate the finetuned BLIP model on COCO, run:
<pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
--config ./configs/retrieval_coco.yaml \
--output_dir output/retrieval_coco \
--evaluate</pre>
3. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/retrieval_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
<pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
--config ./configs/retrieval_coco.yaml \
--output_dir output/retrieval_coco </pre>
### Image-Text Captioning:
1. Download COCO and NoCaps datasets from the original websites, and set 'image_root' in configs/caption_coco.yaml and configs/nocaps.yaml accordingly.
2. To evaluate the finetuned BLIP model on COCO, run:
<pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py --evaluate</pre>
3. To evaluate the finetuned BLIP model on NoCaps, generate results with the following command (the evaluation itself must be performed on the official server):
<pre>python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py </pre>
4. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/caption_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
<pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py </pre>
### VQA:
1. Download VQA v2 dataset and Visual Genome dataset from the original websites, and set 'vqa_root' and 'vg_root' in configs/vqa.yaml.
2. To evaluate the finetuned BLIP model, generate results with the following command (the evaluation itself must be performed on the official server):
<pre>python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --evaluate</pre>
3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/vqa.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
<pre>python -m torch.distributed.run --nproc_per_node=16 train_vqa.py </pre>
### NLVR2:
1. Download NLVR2 dataset from the original websites, and set 'image_root' in configs/nlvr.yaml.
2. To evaluate the finetuned BLIP model, run
<pre>python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --evaluate</pre>
3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/nlvr.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
<pre>python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py </pre>
### Finetune with ViT-L:
To finetune a model with ViT-L, change the config file to set 'vit' to 'large'. Batch size and learning rate may also need to be adjusted accordingly (please see the paper's appendix for hyper-parameter details). <a href="https://github.com/facebookresearch/fairscale">Gradient checkpointing</a> can also be activated in the config file to reduce GPU memory usage.
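As a rough illustration of the config change described above, the sketch below merges ViT-L settings into a base captioning config. The values mirror the active and commented-out lines of configs/caption_coco.yaml in this commit; the plain-dict representation and the `apply_overrides` helper are assumptions made for illustration and are not part of the BLIP code.

```python
# Base settings, copied from the active lines of configs/caption_coco.yaml.
base_config = {
    "vit": "base",
    "vit_grad_ckpt": False,
    "vit_ckpt_layer": 0,
    "batch_size": 32,
    "init_lr": 1e-5,
}

# ViT-L settings, copied from the commented-out lines of the same file.
vit_large_overrides = {
    "vit": "large",
    "vit_grad_ckpt": True,
    "vit_ckpt_layer": 5,
    "batch_size": 16,
    "init_lr": 2e-6,
}


def apply_overrides(config, overrides):
    """Hypothetical helper: return a copy of config with overrides applied.

    Unknown keys are rejected so that a typo does not silently add a new field.
    """
    unknown = set(overrides) - set(config)
    if unknown:
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    merged = dict(config)
    merged.update(overrides)
    return merged


large_config = apply_overrides(base_config, vit_large_overrides)
print(large_config)
```

Keeping the overrides in a separate dict leaves the base settings untouched, which makes it easy to switch back to ViT-B.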
### Pre-train:
1. Prepare training json files where each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}.
2. In configs/pretrain.yaml, set 'train_file' to the paths of the json files.
3. Pre-train the model using 8 A100 GPUs:
<pre>python -m torch.distributed.run --nproc_per_node=8 pretrain.py --config ./configs/pretrain.yaml --output_dir output/Pretrain </pre>
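The training json format from step 1 can be produced with a few lines of Python. This is a minimal sketch: the helper name `write_pretrain_json`, the image paths, the captions, and the output file name are all hypothetical placeholders.

```python
import json
import os
import tempfile


def write_pretrain_json(pairs, path):
    """Write (image_path, caption) pairs as a list of dicts with the
    'image' and 'caption' keys described in step 1."""
    records = [{"image": image_path, "caption": caption}
               for image_path, caption in pairs]
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False)
    return records


# Hypothetical example data; replace with your own image paths and captions.
example_pairs = [
    ("images/0001.jpg", "a dog running on the beach"),
    ("images/0002.jpg", "two people riding bicycles"),
]

out_path = os.path.join(tempfile.mkdtemp(), "my_pretrain.json")
write_pretrain_json(example_pairs, out_path)

# Read the file back to confirm it is a list of two-key dictionaries.
with open(out_path, encoding="utf-8") as f:
    loaded = json.load(f)
print(loaded[0])
```

The resulting file path is what would be listed under 'train_file' in the pre-training config.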
### Zero-shot video-text retrieval:
1. Download MSRVTT dataset following the instructions from https://github.com/salesforce/ALPRO, and set 'video_root' accordingly in configs/retrieval_msrvtt.yaml.
2. Install [decord](https://github.com/dmlc/decord) with <pre>pip install decord</pre>
3. To perform zero-shot evaluation, run
<pre>python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py</pre>
### Pre-training datasets download:
We provide bootstrapped pre-training datasets as json files. Each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'url': url_of_image, 'caption': text_of_image}.
Image source | Filtered web caption | Filtered synthetic caption by ViT-B | Filtered synthetic caption by ViT-L
--- | :---: | :---: | :---:
CC3M+CC12M+SBU | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered_large.json">Download</a>
LAION115M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered_large.json">Download</a>
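Once one of these json files has been downloaded, its entries can be read as (url, caption) pairs. The sketch below works on an in-memory sample rather than a real download; the helper name `iter_url_caption_pairs` and the sample URLs are assumptions for illustration only.

```python
import json


def iter_url_caption_pairs(json_text):
    """Yield (url, caption) tuples from a json string holding a list of
    {'url': ..., 'caption': ...} dictionaries."""
    for item in json.loads(json_text):
        yield item["url"], item["caption"]


# Hypothetical stand-in for the contents of a downloaded file.
sample = json.dumps([
    {"url": "https://example.com/a.jpg", "caption": "a cat on a sofa"},
    {"url": "https://example.com/b.jpg", "caption": "a red bus"},
])

pairs = list(iter_url_caption_pairs(sample))
for url, caption in pairs:
    print(url, "->", caption)
```

Note that these files store image URLs, whereas the pre-training json described earlier expects local image paths, so the images would need to be fetched and the entries rewritten before training.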
### Citation
If you find this code to be useful for your research, please consider citing.
<pre>
@inproceedings{li2022blip,
title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation},
author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
year={2022},
booktitle={ICML},
}</pre>
### Acknowledgement
The implementation of BLIP relies on resources from <a href="https://github.com/salesforce/ALBEF">ALBEF</a>, <a href="https://github.com/huggingface/transformers">Huggingface Transformers</a>, and <a href="https://github.com/rwightman/pytorch-image-models/tree/master/timm">timm</a>. We thank the original authors for open-sourcing their work.
## Security
Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
as soon as it is discovered. This library limits its runtime dependencies in
order to reduce the total cost of ownership as much as can be, but all consumers
should remain vigilant and have their security stakeholders review all third-party
products (3PP) like this one and their dependencies.
build:
  gpu: true
  cuda: "11.1"
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "ipython==7.30.1"
    - "torchvision==0.11.1"
    - "torch==1.10.0"
    - "timm==0.4.12"
    - "transformers==4.15.0"
    - "fairscale==0.4.4"
    - "pycocoevalcap==1.2"
predict: "predict.py:Predictor"
{
"architectures": [
"BertModel"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30522,
"encoder_width": 768,
"add_cross_attention": true
}
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
coco_gt_root: 'annotation/coco_gt'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0
batch_size: 32
init_lr: 1e-5
# vit: 'large'
# vit_grad_ckpt: True
# vit_ckpt_layer: 5
# batch_size: 16
# init_lr: 2e-6
image_size: 384
# generation configs
max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '
# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 5
{
"architectures": [
"BertModel"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30524,
"encoder_width": 768,
"add_cross_attention": true
}
image_root: '/export/share/datasets/vision/NLVR2/'
ann_root: 'annotation'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
#size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 64
vit_grad_ckpt: False
vit_ckpt_layer: 0
max_epoch: 15
image_size: 384
# optimizer
weight_decay: 0.05
init_lr: 3e-5
min_lr: 0
image_root: '/export/share/datasets/vision/nocaps/'
ann_root: 'annotation'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
vit: 'base'
batch_size: 32
image_size: 384
max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '
\ No newline at end of file
train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
'/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
]
laion_path: ''
# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0
image_size: 224
batch_size: 75
queue_size: 57600
alpha: 0.4
# optimizer
weight_decay: 0.05
init_lr: 3e-4
min_lr: 1e-6
warmup_lr: 1e-6
lr_decay_rate: 0.9
max_epoch: 20
warmup_steps: 3000
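The optimizer fields above suggest a linear warmup followed by a per-epoch decay of the learning rate. The sketch below is one plausible reading of how those fields could combine; the function names and the exact formulas are assumptions for illustration and are not taken from pretrain.py, which is not shown here.

```python
# Values copied from the pre-training config above.
config = {
    "init_lr": 3e-4,
    "min_lr": 1e-6,
    "warmup_lr": 1e-6,
    "lr_decay_rate": 0.9,
    "warmup_steps": 3000,
    "max_epoch": 20,
}


def warmup_schedule(step, warmup_steps, start_lr, target_lr):
    """Assumed linear warmup: ramp from start_lr to target_lr over warmup_steps."""
    return min(target_lr, start_lr + (target_lr - start_lr) * step / warmup_steps)


def step_decay_schedule(epoch, init_lr, min_lr, decay_rate):
    """Assumed per-epoch decay: multiply by decay_rate each epoch, floored at min_lr."""
    return max(min_lr, init_lr * decay_rate ** epoch)


# Learning rate at a few warmup steps.
for step in (0, 1500, 3000):
    lr = warmup_schedule(step, config["warmup_steps"],
                         config["warmup_lr"], config["init_lr"])
    print(f"warmup step {step}: lr={lr:.3e}")

# Learning rate at the start of a few epochs.
for epoch in (0, 1, 10, 19):
    lr = step_decay_schedule(epoch, config["init_lr"],
                             config["min_lr"], config["lr_decay_rate"])
    print(f"epoch {epoch}: lr={lr:.3e}")
```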
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
dataset: 'coco'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
# size of vit model; base or large
vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5
# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 12
# init_lr: 5e-6
image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 256
negative_all_rank: True
# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6
image_root: '/export/share/datasets/vision/flickr30k/'
ann_root: 'annotation'
dataset: 'flickr'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
# size of vit model; base or large
vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5
# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 10
# init_lr: 5e-6
image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 128
negative_all_rank: False
# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6
video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
ann_root: 'annotation'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
# size of vit model; base or large
vit: 'base'
batch_size: 64
k_test: 128
image_size: 384
num_frm_test: 8
\ No newline at end of file
vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
train_files: ['vqa_train','vqa_val','vg_qa']
ann_root: 'annotation'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
# size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 32
vit_grad_ckpt: False
vit_ckpt_layer: 0
init_lr: 2e-5
image_size: 480
k_test: 128
inference: 'rank'
# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 10
\ No newline at end of file
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
from data.nocaps_dataset import nocaps_eval
from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
from data.vqa_dataset import vqa_dataset
from data.nlvr_dataset import nlvr_dataset
from data.pretrain_dataset import pretrain_dataset
from transform.randaugment import RandomAugment
def create_dataset(dataset, config, min_scale=0.5):

    normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

    transform_train = transforms.Compose([
            transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
                                              'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
            transforms.ToTensor(),
            normalize,
        ])
    transform_test = transforms.Compose([
        transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
        ])

    if dataset=='pretrain':
        dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
        return dataset

    elif dataset=='caption_coco':
        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
        val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset=='nocaps':
        val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return val_dataset, test_dataset

    elif dataset=='retrieval_coco':
        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
        val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset=='retrieval_flickr':
        train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
        val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset=='vqa':
        train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
                                    train_files = config['train_files'], split='train')
        test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
        return train_dataset, test_dataset

    elif dataset=='nlvr':
        train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
        val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
        test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
        return train_dataset, val_dataset, test_dataset


def create_sampler(datasets, shuffles, num_tasks, global_rank):
    samplers = []
    for dataset,shuffle in zip(datasets,shuffles):
        sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
        samplers.append(sampler)
    return samplers


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    loaders = []
    for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
        if is_train:
            shuffle = (sampler is None)
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=True,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
        )
        loaders.append(loader)
    return loaders
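create_sampler builds one DistributedSampler per dataset, so each process only sees its own slice of the data. As a rough mental model, the pure-Python sketch below shows how indices are split across ranks when shuffling is disabled and the sampler's other options are left at their defaults: the index list is padded by wrapping around until its length is a multiple of num_tasks, and each rank then takes every num_tasks-th index. The helper name `shard_indices` is hypothetical, and this is an illustration of the partitioning scheme rather than the PyTorch implementation.

```python
import math


def shard_indices(dataset_len, num_tasks, global_rank):
    """Return the dataset indices that one rank would visit (no shuffling)."""
    per_rank = math.ceil(dataset_len / num_tasks)
    total = per_rank * num_tasks
    indices = list(range(dataset_len))
    # Pad by wrapping around so every rank receives the same number of samples.
    indices += [indices[i % dataset_len] for i in range(total - dataset_len)]
    # Rank r takes indices r, r + num_tasks, r + 2 * num_tasks, ...
    return indices[global_rank:total:num_tasks]


# Example: 10 samples split across 4 processes.
shards = [shard_indices(10, 4, rank) for rank in range(4)]
for rank, shard in enumerate(shards):
    print(f"rank {rank}: {shard}")
```

Because of the padding, a few samples can be seen by two ranks, which is worth keeping in mind when aggregating evaluation results across processes.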
import os
import json
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
from data.utils import pre_caption
class coco_karpathy_train(Dataset):
    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
        '''
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        '''
        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
        filename = 'coco_karpathy_train.json'

        download_url(url,ann_root)

        self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.prompt = prompt

        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            img_id = ann['image_id']
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.image_root,ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        caption = self.prompt+pre_caption(ann['caption'], self.max_words)

        return image, caption, self.img_ids[ann['image_id']]


class coco_karpathy_caption_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split):
        '''
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        split (string): val or test
        '''
        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
        filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}

        download_url(urls[split],ann_root)

        self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
        self.transform = transform
        self.image_root = image_root

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.image_root,ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]

        return image, int(img_id)


class coco_karpathy_retrieval_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split, max_words=30):
        '''
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        split (string): val or test
        '''
        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
        filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}

        download_url(urls[split],ann_root)

        self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
        self.transform = transform
        self.image_root = image_root

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for i, caption in enumerate(ann['caption']):
                self.text.append(pre_caption(caption,max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        image_path = os.path.join(self.image_root, self.annotation[index]['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        return image, index
\ No newline at end of file
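coco_karpathy_caption_eval derives a numeric image id from the file name by calling strip('.jpg') and taking the text after the last underscore. Since str.strip removes any of the given characters from both ends rather than a literal suffix, it can also eat leading letters of the file name; the numeric id survives because digits are not in the stripped set. The sketch below shows an alternative that removes the extension explicitly. The helper name `image_id_from_path` and the example paths are assumptions for illustration.

```python
import os


def image_id_from_path(path):
    """Return the integer after the last underscore of the file name,
    with the extension removed via os.path.splitext."""
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.split('_')[-1])


# Illustrative COCO-style relative path.
coco_path = 'val2014/COCO_val2014_000000391895.jpg'
print(image_id_from_path(coco_path))

# str.strip treats its argument as a set of characters, not a suffix:
# the leading 'p' of 'photo' is removed along with the trailing '.jpg'.
print('photo_7.jpg'.strip('.jpg'))
```

Both approaches agree on COCO-style names; the explicit version simply does not depend on which characters happen to start or end the file name.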
import os
import json
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
from data.utils import pre_caption
class flickr30k_train(Dataset):
def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
'''
image_root (string): Root directory of images (e.g. flickr30k/)
ann_root (string): directory to store the annotation file
'''
url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
filename = 'flickr30k_train.json'
download_url(url,ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
self.transform = transform
self.image_root = image_root
self.max_words = max_words
self.prompt = prompt
self.img_ids = {}
n = 0
for ann in self.annotation:
img_id = ann['image_id']
if img_id not in self.img_ids.keys():
self.img_ids[img_id] = n
n += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
caption = self.prompt+pre_caption(ann['caption'], self.max_words)
return image, caption, self.img_ids[ann['image_id']]
class flickr30k_retrieval_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split, max_words=30):
'''
image_root (string): Root directory of images (e.g. flickr30k/)
ann_root (string): directory to store the annotation file
split (string): val or test
'''
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
self.text = []
self.image = []
self.txt2img = {}
self.img2txt = {}
txt_id = 0
for img_id, ann in enumerate(self.annotation):
self.image.append(ann['image'])
self.img2txt[img_id] = []
for i, caption in enumerate(ann['caption']):
self.text.append(pre_caption(caption,max_words))
self.img2txt[img_id].append(txt_id)
self.txt2img[txt_id] = img_id
txt_id += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
image_path = os.path.join(self.image_root, self.annotation[index]['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
return image, index
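# Illustrative sketch (not part of the repository): how the txt2img / img2txt
# maps built in flickr30k_retrieval_eval.__init__ relate flat caption ids to
# image ids. The helper name `build_retrieval_index` and the toy annotations
# are assumptions made for this example; caption preprocessing is omitted.

```python
def build_retrieval_index(annotation):
    """Flatten per-image caption lists and record both id mappings."""
    texts, images = [], []
    txt2img, img2txt = {}, {}
    txt_id = 0
    for img_id, ann in enumerate(annotation):
        images.append(ann['image'])
        img2txt[img_id] = []
        for caption in ann['caption']:
            texts.append(caption)
            img2txt[img_id].append(txt_id)  # image -> all of its caption ids
            txt2img[txt_id] = img_id        # caption -> its single image id
            txt_id += 1
    return texts, images, txt2img, img2txt


# hypothetical toy annotations in the same shape as the retrieval JSON
toy_annotation = [
    {'image': 'a.jpg', 'caption': ['a dog', 'a brown dog']},
    {'image': 'b.jpg', 'caption': ['a cat']},
]
texts, images, txt2img, img2txt = build_retrieval_index(toy_annotation)
```

# Caption ids are global across the dataset, so len(texts) can exceed
# len(images); the dataset's __len__ counts images, not captions.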
\ No newline at end of file
import os
import json
import random
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
from data.utils import pre_caption
class nlvr_dataset(Dataset):
def __init__(self, transform, image_root, ann_root, split):
'''
image_root (string): Root directory of images
ann_root (string): directory to store the annotation file
split (string): train, val or test
'''
urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image0_path = os.path.join(self.image_root,ann['images'][0])
image0 = Image.open(image0_path).convert('RGB')
image0 = self.transform(image0)
image1_path = os.path.join(self.image_root,ann['images'][1])
image1 = Image.open(image1_path).convert('RGB')
image1 = self.transform(image1)
sentence = pre_caption(ann['sentence'], 40)
if ann['label']=='True':
label = 1
else:
label = 0
words = sentence.split(' ')
if 'left' not in words and 'right' not in words:
if random.random()<0.5:
return image0, image1, sentence, label
else:
return image1, image0, sentence, label
else:
if random.random()<0.5:
return image0, image1, sentence, label
else:
new_words = []
for word in words:
if word=='left':
new_words.append('right')
elif word=='right':
new_words.append('left')
else:
new_words.append(word)
sentence = ' '.join(new_words)
return image1, image0, sentence, label
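# Illustrative sketch (not part of the repository): the left/right word swap
# that nlvr_dataset.__getitem__ applies when it flips the image order. The
# helper name `swap_left_right` is an assumption made for this example.

```python
def swap_left_right(sentence):
    """Exchange the whole words 'left' and 'right' in a space-separated sentence."""
    swap = {'left': 'right', 'right': 'left'}
    return ' '.join(swap.get(word, word) for word in sentence.split(' '))


original = 'the left image has more dogs than the right image'
flipped = swap_left_right(original)
```

# Only exact tokens are swapped, so a word such as 'leftmost' is left alone,
# matching the word-by-word comparison in the loop above.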
\ No newline at end of file
import os
import json
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
class nocaps_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split):
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
return image, int(ann['img_id'])
\ No newline at end of file
import json
import os
import random
from torch.utils.data import Dataset
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
from data.utils import pre_caption
import glob
class pretrain_dataset(Dataset):
def __init__(self, ann_file, laion_path, transform):
self.ann_pretrain = []
for f in ann_file:
print('loading '+f)
ann = json.load(open(f,'r'))
self.ann_pretrain += ann
self.laion_path = laion_path
if self.laion_path:
self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))
print('loading '+self.laion_files[0])
with open(self.laion_files[0],'r') as f:
self.ann_laion = json.load(f)
self.annotation = self.ann_pretrain + self.ann_laion
else:
self.annotation = self.ann_pretrain
self.transform = transform
def reload_laion(self, epoch):
n = epoch%len(self.laion_files)
print('loading '+self.laion_files[n])
with open(self.laion_files[n],'r') as f:
self.ann_laion = json.load(f)
self.annotation = self.ann_pretrain + self.ann_laion
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image = Image.open(ann['image']).convert('RGB')
image = self.transform(image)
caption = pre_caption(ann['caption'],30)
return image, caption
\ No newline at end of file
import re
import json
import os
import torch
import torch.distributed as dist
import utils
def pre_caption(caption,max_words=50):
    # lowercase and replace selected punctuation with spaces
    caption = re.sub(
        r"([.!\"()*#:;~])",
        ' ',
        caption.lower(),
    )
    # collapse runs of whitespace into a single space
    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n')
    caption = caption.strip(' ')

    # truncate caption
    caption_words = caption.split(' ')
    if len(caption_words)>max_words:
        caption = ' '.join(caption_words[:max_words])

    return caption
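
# Illustrative sketch (not part of the repository): a standalone copy of
# pre_caption with a worked input, to show what the two regex passes and the
# truncation do. The sample caption is made up for this example.

```python
import re


def pre_caption(caption, max_words=50):
    # lowercase, then turn selected punctuation into spaces
    caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
    # collapse any run of two or more whitespace characters
    caption = re.sub(r"\s{2,}", ' ', caption)
    caption = caption.rstrip('\n').strip(' ')
    # keep at most max_words space-separated words
    words = caption.split(' ')
    if len(words) > max_words:
        caption = ' '.join(words[:max_words])
    return caption


full = pre_caption('A man. Riding!!  a horse (outdoors)\n')
short = pre_caption('A man. Riding!!  a horse (outdoors)\n', max_words=4)
```

# Commas and question marks are not in the character class, so they survive
# this cleaning step.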

def pre_question(question,max_ques_words=50):
    # lowercase and delete selected punctuation
    question = re.sub(
        r"([.!\"()*#:;~])",
        '',
        question.lower(),
    )
    question = question.rstrip(' ')

    # truncate question
    question_words = question.split(' ')
    if len(question_words)>max_ques_words:
        question = ' '.join(question_words[:max_ques_words])

    return question

def save_result(result, result_dir, filename, remove_duplicate=''):
    result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
    final_result_file = os.path.join(result_dir, '%s.json'%filename)

    # each process writes its own partial result file
    with open(result_file,'w') as f:
        json.dump(result,f)

    dist.barrier()

    if utils.is_main_process():
        # combine results from all processes
        result = []

        for rank in range(utils.get_world_size()):
            result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
            with open(result_file,'r') as f:
                res = json.load(f)
            result += res

        if remove_duplicate:
            # keep only the first entry for each value of the given key
            result_new = []
            id_list = []
            for res in result:
                if res[remove_duplicate] not in id_list:
                    id_list.append(res[remove_duplicate])
                    result_new.append(res)
            result = result_new

        with open(final_result_file,'w') as f:
            json.dump(result,f)
        print('result file saved to %s'%final_result_file)

    return final_result_file
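
# Illustrative sketch (not part of the repository): the "keep first occurrence
# per key" step of save_result, written with a set for O(1) membership checks.
# The helper name `dedup_by_key` and the toy records are assumptions; the set
# variant requires the key values to be hashable.

```python
def dedup_by_key(results, key):
    """Return results with later duplicates of results[i][key] dropped."""
    seen = set()
    unique = []
    for res in results:
        if res[key] not in seen:
            seen.add(res[key])
            unique.append(res)
    return unique


# hypothetical merged output from two ranks; image 7 was evaluated twice
merged = [
    {'image_id': 7, 'caption': 'a dog'},
    {'image_id': 9, 'caption': 'a cat'},
    {'image_id': 7, 'caption': 'a dog running'},
]
unique = dedup_by_key(merged, 'image_id')
```

# Duplicates can appear because a distributed sampler may pad the last batch
# so that every rank receives the same number of samples.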

from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
from torchvision.datasets.utils import download_url

def coco_caption_eval(coco_gt_root, results_file, split):
    urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
            'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
    filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}

    download_url(urls[split],coco_gt_root)
    annotation_file = os.path.join(coco_gt_root,filenames[split])

    # create coco object and coco_result object
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    # create coco_eval object by taking coco and coco_result
    coco_eval = COCOEvalCap(coco, coco_result)

    # to evaluate on a subset of images, uncomment the line below;
    # leave it commented out when evaluating the full validation set
    # coco_eval.params['image_id'] = coco_result.getImgIds()

    # evaluate results
    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    # print output evaluation scores
    for metric, score in coco_eval.eval.items():
        print(f'{metric}: {score:.3f}')

    return coco_eval
\ No newline at end of file
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
import torch
import numpy as np
import random
import decord
from decord import VideoReader
import json
import os
from data.utils import pre_caption
decord.bridge.set_bridge("torch")
class ImageNorm(object):
"""Apply Normalization to Image Pixels on GPU
"""
def __init__(self, mean, std):
self.mean = torch.tensor(mean).view(1, 3, 1, 1)
self.std = torch.tensor(std).view(1, 3, 1, 1)
def __call__(self, img):
if torch.max(img) > 1 and self.mean.max() <= 1:
img.div_(255.)
return img.sub_(self.mean).div_(self.std)
def load_jsonl(filename):
with open(filename, "r") as f:
return [json.loads(l.strip("\n")) for l in f.readlines()]
class VideoDataset(Dataset):
def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'):
'''
video_root (string): Root directory of videos
ann_root (string): directory to store the annotation file
'''
url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl'
filename = 'msrvtt_test.jsonl'
download_url(url,ann_root)
self.annotation = load_jsonl(os.path.join(ann_root,filename))
self.num_frm = num_frm
self.frm_sampling_strategy = frm_sampling_strategy
self.max_img_size = max_img_size
self.video_root = video_root
self.video_fmt = video_fmt
self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
self.text = [pre_caption(ann['caption'],40) for ann in self.annotation]
self.txt2video = [i for i in range(len(self.annotation))]
self.video2txt = self.txt2video
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt)
vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
video = self.img_norm(vid_frm_array.float())
return video, ann['clip_name']
def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
try:
if not height or not width:
vr = VideoReader(video_path)
else:
vr = VideoReader(video_path, width=width, height=height)
vlen = len(vr)
if start_time or end_time:
assert fps > 0, 'must provide video fps if specifying start and end time.'
start_idx = min(int(start_time * fps), vlen)
end_idx = min(int(end_time * fps), vlen)
else:
start_idx, end_idx = 0, vlen
if self.frm_sampling_strategy == 'uniform':
frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)
elif self.frm_sampling_strategy == 'rand':
frame_indices = sorted(random.sample(range(vlen), self.num_frm))
elif self.frm_sampling_strategy == 'headtail':
frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
frame_indices = frame_indices_head + frame_indices_tail
else:
raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))
raw_sample_frms = vr.get_batch(frame_indices)
except Exception as e:
return None
raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)
return raw_sample_frms
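
# Illustrative sketch (not part of the repository): evenly spaced frame
# sampling in pure Python, to show the idea behind the 'uniform' strategy.
# This is NOT the exact np.arange call used above; the helper name and the
# clamping to the last frame are assumptions made for this example.

```python
def uniform_frame_indices(num_video_frames, num_samples):
    """Pick num_samples frame indices spread evenly over the video."""
    step = num_video_frames / num_samples
    last = num_video_frames - 1
    # int() floors the fractional position; min() guards the upper bound
    return [min(int(i * step), last) for i in range(num_samples)]


indices = uniform_frame_indices(10, 4)
short_clip = uniform_frame_indices(3, 4)
```

# When the clip has fewer frames than requested samples, some indices repeat,
# which keeps the output length fixed at num_samples.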
import os
import json
import random
from PIL import Image
import torch
from torch.utils.data import Dataset
from data.utils import pre_question
from torchvision.datasets.utils import download_url
class vqa_dataset(Dataset):
def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
self.split = split
self.transform = transform
self.vqa_root = vqa_root
self.vg_root = vg_root
if split=='train':
urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}
self.annotation = []
for f in train_files:
download_url(urls[f],ann_root)
self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r'))
else:
download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root)
self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r'))
download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root)
self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r'))
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
if ann['dataset']=='vqa':
image_path = os.path.join(self.vqa_root,ann['image'])
elif ann['dataset']=='vg':
image_path = os.path.join(self.vg_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
if self.split == 'test':
question = pre_question(ann['question'])
question_id = ann['question_id']
return image, question, question_id
elif self.split=='train':
question = pre_question(ann['question'])
if ann['dataset']=='vqa':
answer_weight = {}
for answer in ann['answer']:
if answer in answer_weight.keys():
answer_weight[answer] += 1/len(ann['answer'])
else:
answer_weight[answer] = 1/len(ann['answer'])
answers = list(answer_weight.keys())
weights = list(answer_weight.values())
elif ann['dataset']=='vg':
answers = [ann['answer']]
weights = [0.2]
return image, question, answers, weights

def vqa_collate_fn(batch):
    image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
    for image, question, answer, weights in batch:
        image_list.append(image)
        question_list.append(question)
        # answers and weights are flattened across the batch;
        # n records how many answers belong to each question
        weight_list += weights
        answer_list += answer
        n.append(len(answer))
    return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n
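
# Illustrative sketch (not part of the repository): how per-question answer
# weights are derived from repeated annotator answers, and how a batch is
# flattened with a per-question count `n`. The helper names and toy data are
# assumptions; weights are computed as count/total here rather than by
# repeated addition, so floating-point results may differ in the last bits.

```python
from collections import Counter


def answer_weights(answers):
    """Unique answers (first-seen order) and the fraction of votes for each."""
    counts = Counter(answers)
    total = len(answers)
    return list(counts.keys()), [c / total for c in counts.values()]


def flatten_answers(batch):
    """batch: list of (answers, weights) pairs, one pair per question."""
    flat_answers, flat_weights, n = [], [], []
    for answers, weights in batch:
        flat_answers += answers
        flat_weights += weights
        n.append(len(answers))
    return flat_answers, flat_weights, n


answers, weights = answer_weights(['yes', 'yes', 'no', 'yes'])
flat_answers, flat_weights, n = flatten_answers([(answers, weights), (['red'], [0.2])])
```

# The list `n` is what lets the model repeat each question's encoder output
# once per candidate answer before decoding.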
\ No newline at end of file
'''
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader
from models.blip import blip_decoder
import utils
from data import create_dataset, create_sampler, create_loader
from data.utils import save_result
@torch.no_grad()
def evaluate(model, data_loader, device, config):
# evaluate
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Evaluation:'
print_freq = 10
result = []
for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
image = image.to(device)
captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
min_length=config['min_length'], repetition_penalty=1.1)
for caption, img_id in zip(captions, image_id):
result.append({"image_id": img_id.item(), "caption": caption})
return result
def main(args, config):
utils.init_distributed_mode(args)
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.benchmark = True
#### Dataset ####
print("Creating captioning dataset")
val_dataset, test_dataset = create_dataset('nocaps', config)
if args.distributed:
num_tasks = utils.get_world_size()
global_rank = utils.get_rank()
samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
else:
samplers = [None,None]
val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers,
batch_size=[config['batch_size']]*2,num_workers=[4,4],
is_trains=[False, False], collate_fns=[None,None])
#### Model ####
print("Creating model")
model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
prompt=config['prompt'])
model = model.to(device)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
val_result = evaluate(model_without_ddp, val_loader, device, config)
val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
test_result = evaluate(model_without_ddp, test_loader, device, config)
test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='./configs/nocaps.yaml')
parser.add_argument('--output_dir', default='output/NoCaps')
parser.add_argument('--device', default='cuda')
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
parser.add_argument('--distributed', default=True, type=bool)
args = parser.parse_args()
config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
args.result_dir = os.path.join(args.output_dir, 'result')
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
Path(args.result_dir).mkdir(parents=True, exist_ok=True)
yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
main(args, config)
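
# Illustrative sketch (not part of the repository): a caveat about the
# `--distributed` flag declared with type=bool above. argparse calls
# bool(<string>), and every non-empty string is truthy, so passing
# `--distributed False` still yields True. `str2bool` is a hypothetical
# helper shown as one possible workaround, not the authors' approach.

```python
import argparse

# same declaration style as in the script above
parser = argparse.ArgumentParser()
parser.add_argument('--distributed', default=True, type=bool)
surprising = parser.parse_args(['--distributed', 'False']).distributed


def str2bool(value):
    """Parse common textual spellings of a boolean (hypothetical helper)."""
    lowered = value.lower()
    if lowered in ('true', '1', 'yes'):
        return True
    if lowered in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


fixed_parser = argparse.ArgumentParser()
fixed_parser.add_argument('--distributed', default=True, type=str2bool)
fixed = fixed_parser.parse_args(['--distributed', 'False']).distributed
```

# With the default left in place the script behaves as written; the caveat
# only matters when the flag is passed explicitly on the command line.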
\ No newline at end of file
'''
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader
from models.blip_retrieval import blip_retrieval
import utils
from data.video_dataset import VideoDataset
@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
# test
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Evaluation:'
print('Computing features for evaluation...')
start_time = time.time()
texts = data_loader.dataset.text
num_text = len(texts)
text_bs = 256
text_ids = []
text_embeds = []
text_atts = []
for i in range(0, num_text, text_bs):
text = texts[i: min(num_text, i+text_bs)]
text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
text_embeds.append(text_embed)
text_ids.append(text_input.input_ids)
text_atts.append(text_input.attention_mask)
text_embeds = torch.cat(text_embeds,dim=0)
text_ids = torch.cat(text_ids,dim=0)
text_atts = torch.cat(text_atts,dim=0)
text_ids[:,0] = tokenizer.additional_special_tokens_ids[0]
video_feats = []
video_embeds = []
for video, video_id in data_loader:
B,N,C,W,H = video.size()
video = video.view(-1,C,W,H)
video = video.to(device,non_blocking=True)
video_feat = model.visual_encoder(video)
video_embed = model.vision_proj(video_feat[:,0,:])
video_embed = video_embed.view(B,N,-1).mean(dim=1)
video_embed = F.normalize(video_embed,dim=-1)
video_feat = video_feat.view(B,-1,video_feat.shape[-1])
video_feats.append(video_feat.cpu())
video_embeds.append(video_embed)
video_feats = torch.cat(video_feats,dim=0)
video_embeds = torch.cat(video_embeds,dim=0)
sims_matrix = video_embeds @ text_embeds.t()
score_matrix_v2t = torch.full((len(texts),len(texts)),-100.0).to(device)
num_tasks = utils.get_world_size()
rank = utils.get_rank()
step = sims_matrix.size(0)//num_tasks + 1
start = rank*step
end = min(sims_matrix.size(0),start+step)
for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
encoder_output = video_feats[start+i].repeat(config['k_test'],1,1).to(device,non_blocking=True)
encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
output = model.text_encoder(text_ids[topk_idx],
attention_mask = text_atts[topk_idx],
encoder_hidden_states = encoder_output,
encoder_attention_mask = encoder_att,
return_dict = True,
)
score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
score_matrix_v2t[start+i,topk_idx] = score + topk_sim
sims_matrix = sims_matrix.t()
score_matrix_t2v = torch.full((len(texts),len(texts)),-100.0).to(device)
step = sims_matrix.size(0)//num_tasks + 1
start = rank*step
end = min(sims_matrix.size(0),start+step)
for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
encoder_output = video_feats[topk_idx].to(device,non_blocking=True)
encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
attention_mask = text_atts[start+i].repeat(config['k_test'],1),
encoder_hidden_states = encoder_output,
encoder_attention_mask = encoder_att,
return_dict = True,
)
score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
score_matrix_t2v[start+i,topk_idx] = score + topk_sim
if dist.is_available() and dist.is_initialized():  # does not rely on the module-level `args`
dist.barrier()
torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Evaluation time {}'.format(total_time_str))
return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()

@torch.no_grad()
def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt):

    #Video->Text
    ranks = np.zeros(scores_v2t.shape[0])
    for index,score in enumerate(scores_v2t):
        inds = np.argsort(score)[::-1]
        ranks[index] = np.where(inds == vid2txt[index])[0][0]

    # Compute metrics
    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    #Text->Video
    ranks = np.zeros(scores_t2v.shape[0])

    for index,score in enumerate(scores_t2v):
        inds = np.argsort(score)[::-1]
        ranks[index] = np.where(inds == txt2vmg[index])[0][0]

    # median rank is reported 1-based
    mdR = np.median(ranks+1)

    # Compute metrics
    vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    vr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    tr_mean = (tr1 + tr5 + tr10) / 3
    vr_mean = (vr1 + vr5 + vr10) / 3
    r_mean = (tr_mean + vr_mean) / 2

    eval_result = {'txt_r1': tr1,
                   'txt_r5': tr5,
                   'txt_r10': tr10,
                   'txt_r_mean': tr_mean,
                   'vid_r1': vr1,
                   'vid_r5': vr5,
                   'vid_r10': vr10,
                   'vid_r_mean': vr_mean,
                   'vid_mdR': mdR,
                   'r_mean': r_mean}
    return eval_result
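
# Illustrative sketch (not part of the repository): recall@k as computed in
# itm_eval, written in pure Python on a toy score matrix. The helper name
# `recall_at_k` and the scores are assumptions; tie-breaking may differ from
# np.argsort, so the toy data avoids equal scores.

```python
def recall_at_k(score_rows, gt_index, ks=(1, 5, 10)):
    """Percentage of queries whose ground-truth item ranks within the top k."""
    ranks = []
    for i, scores in enumerate(score_rows):
        # candidate indices sorted by descending score
        order = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
        ranks.append(order.index(gt_index[i]))  # 0-based rank of the true match
    return {k: 100.0 * sum(r < k for r in ranks) / len(ranks) for k in ks}


# hypothetical 3x3 similarity matrix; query i should retrieve item i
scores = [
    [0.9, 0.1, 0.0],
    [0.2, 0.3, 0.5],
    [0.1, 0.8, 0.4],
]
recalls = recall_at_k(scores, gt_index=[0, 1, 2], ks=(1, 2))
```

# Here only the first query ranks its match first, while all three rank it
# within the top two.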
def main(args, config):
utils.init_distributed_mode(args)
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.benchmark = True
#### Dataset ####
print("Creating retrieval dataset")
test_dataset = VideoDataset(config['video_root'],config['ann_root'],num_frm=config['num_frm_test'],
max_img_size=config['image_size'], frm_sampling_strategy='uniform')
test_loader = DataLoader(
test_dataset,
batch_size=config['batch_size'],
num_workers=4,
pin_memory=True,
drop_last=False,
shuffle=False,
)
#### Model ####
print("Creating model")
model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'])
model = model.to(device)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
score_v2t, score_t2v = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config)
if utils.is_main_process():
test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt)
print(test_result)
log_stats = {**{f'{k}': v for k, v in test_result.items()},}
with open(os.path.join(args.output_dir, "test_result.txt"),"a") as f:
f.write(json.dumps(log_stats) + "\n")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml')
parser.add_argument('--output_dir', default='output/Retrieval_msrvtt')
parser.add_argument('--device', default='cuda')
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
parser.add_argument('--distributed', default=True, type=bool)
args = parser.parse_args()
config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
main(args, config)
\ No newline at end of file
from models.med import BertConfig, BertModel
from transformers import BertTokenizer
import torch
from torch import nn
import torch.nn.functional as F
from models.blip import create_vit, init_tokenizer, load_checkpoint
class BLIP_ITM(nn.Module):
def __init__(self,
med_config = 'configs/med_config.json',
image_size = 384,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
embed_dim = 256,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
text_width = self.text_encoder.config.hidden_size
self.vision_proj = nn.Linear(vision_width, embed_dim)
self.text_proj = nn.Linear(text_width, embed_dim)
self.itm_head = nn.Linear(text_width, 2)
def forward(self, image, caption, match_head='itm'):
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
return_tensors="pt").to(image.device)
if match_head=='itm':
output = self.text_encoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
)
itm_output = self.itm_head(output.last_hidden_state[:,0,:])
return itm_output
elif match_head=='itc':
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text')
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
sim = image_feat @ text_feat.t()
return sim
def blip_itm(pretrained='',**kwargs):
model = BLIP_ITM(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
assert(len(msg.missing_keys)==0)
return model
\ No newline at end of file
from models.med import BertConfig
from models.nlvr_encoder import BertModel
from models.vit import interpolate_pos_embed
from models.blip import create_vit, init_tokenizer, is_url
from timm.models.hub import download_cached_file
import os
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertTokenizer
import numpy as np
class BLIP_NLVR(nn.Module):
def __init__(self,
med_config = 'configs/med_config.json',
image_size = 480,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
self.cls_head = nn.Sequential(
nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
nn.ReLU(),
nn.Linear(self.text_encoder.config.hidden_size, 2)
)
def forward(self, image, text, targets, train=True):
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
text.input_ids[:,0] = self.tokenizer.enc_token_id
output = self.text_encoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = [image0_embeds,image1_embeds],
encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
image_atts[image0_embeds.size(0):]],
return_dict = True,
)
hidden_state = output.last_hidden_state[:,0,:]
prediction = self.cls_head(hidden_state)
if train:
loss = F.cross_entropy(prediction, targets)
return loss
else:
return prediction
def blip_nlvr(pretrained='',**kwargs):
model = BLIP_NLVR(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
print("missing keys:")
print(msg.missing_keys)
return model

def load_checkpoint(model,url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')
    state_dict = checkpoint['model']

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)

    # duplicate each cross-attention weight so both image branches start from the same values
    for key in list(state_dict.keys()):
        if 'crossattention.self.' in key:
            new_key0 = key.replace('self','self0')
            new_key1 = key.replace('self','self1')
            state_dict[new_key0] = state_dict[key]
            state_dict[new_key1] = state_dict[key]
        elif 'crossattention.output.dense.' in key:
            new_key0 = key.replace('dense','dense0')
            new_key1 = key.replace('dense','dense1')
            state_dict[new_key0] = state_dict[key]
            state_dict[new_key1] = state_dict[key]

    msg = model.load_state_dict(state_dict,strict=False)
    print('load checkpoint from %s'%url_or_filename)
    return model,msg
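
# Illustrative sketch (not part of the repository): the key-duplication step
# of load_checkpoint on a plain dict, so the renaming can be inspected without
# a real checkpoint. The helper name and the toy keys/values are assumptions.

```python
def duplicate_cross_attention_keys(state_dict):
    """Copy each cross-attention entry to '...0' and '...1' variants."""
    out = dict(state_dict)
    for key in list(state_dict.keys()):
        if 'crossattention.self.' in key:
            out[key.replace('self', 'self0')] = state_dict[key]
            out[key.replace('self', 'self1')] = state_dict[key]
        elif 'crossattention.output.dense.' in key:
            out[key.replace('dense', 'dense0')] = state_dict[key]
            out[key.replace('dense', 'dense1')] = state_dict[key]
    return out


# hypothetical checkpoint entries; strings stand in for tensors
toy_state = {
    'layer.0.crossattention.self.query.weight': 'W',
    'layer.0.crossattention.output.dense.bias': 'b',
    'layer.0.attention.self.query.weight': 'Q',
}
expanded = duplicate_cross_attention_keys(toy_state)
```

# The original keys stay in the dict; loading with strict=False is what
# allows those now-unused entries to be ignored.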
\ No newline at end of file
from models.med import BertConfig, BertModel, BertLMHeadModel
from models.blip import create_vit, init_tokenizer, load_checkpoint
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertTokenizer
import numpy as np
class BLIP_VQA(nn.Module):
def __init__(self,
med_config = 'configs/med_config.json',
image_size = 480,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
self.tokenizer = init_tokenizer()
encoder_config = BertConfig.from_json_file(med_config)
encoder_config.encoder_width = vision_width
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
decoder_config = BertConfig.from_json_file(med_config)
self.text_decoder = BertLMHeadModel(config=decoder_config)

    def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

        question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
                                  return_tensors="pt").to(image.device)
        question.input_ids[:,0] = self.tokenizer.enc_token_id

        if train:
            '''
            n: number of answers for each question
            weights: weight for each answer
            '''
            answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
            answer.input_ids[:,0] = self.tokenizer.bos_token_id
            answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)

            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)

            # repeat each question's encoder output once per answer
            question_states = []
            question_atts = []
            for b, num_answers in enumerate(n):
                question_states += [question_output.last_hidden_state[b]]*num_answers
                question_atts += [question.attention_mask[b]]*num_answers
            question_states = torch.stack(question_states,0)
            question_atts = torch.stack(question_atts,0)

            answer_output = self.text_decoder(answer.input_ids,
                                              attention_mask = answer.attention_mask,
                                              encoder_hidden_states = question_states,
                                              encoder_attention_mask = question_atts,
                                              labels = answer_targets,
                                              return_dict = True,
                                              reduction = 'none',
                                              )

            # weight the per-answer losses and average over the batch of images
            loss = weights * answer_output.loss
            loss = loss.sum()/image.size(0)

            return loss

        else:
            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)

            if inference=='generate':
                num_beams = 3
                question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
                question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
                model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}

                bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)

                outputs = self.text_decoder.generate(input_ids=bos_ids,
                                                     max_length=10,
                                                     min_length=1,
                                                     num_beams=num_beams,
                                                     eos_token_id=self.tokenizer.sep_token_id,
                                                     pad_token_id=self.tokenizer.pad_token_id,
                                                     **model_kwargs)

                answers = []
                for output in outputs:
                    answer = self.tokenizer.decode(output, skip_special_tokens=True)
                    answers.append(answer)
                return answers

            elif inference=='rank':
                # here `answer` must already be tokenized (it needs .input_ids and .attention_mask)
                max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
                                           answer.input_ids, answer.attention_mask, k_test)
                return max_ids

    def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):

        num_ques = question_states.size(0)
        start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token

        start_output = self.text_decoder(start_ids,
                                         encoder_hidden_states = question_states,
                                         encoder_attention_mask = question_atts,
                                         return_dict = True,
                                         reduction = 'none')
        logits = start_output.logits[:,0,:] # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:,1]
        prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(k,dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids,dim=0)
        input_atts = torch.cat(input_atts,dim=0)

        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, k)
        question_atts = tile(question_atts, 0, k)

        output = self.text_decoder(input_ids,
                                   attention_mask = input_atts,
                                   encoder_hidden_states = question_states,
                                   encoder_attention_mask = question_atts,
                                   labels = targets_ids,
                                   return_dict = True,
                                   reduction = 'none')

        log_probs_sum = -output.loss
        log_probs_sum = log_probs_sum.view(num_ques,k)

        max_topk_ids = log_probs_sum.argmax(dim=1)
        max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]

        return max_ids
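
`rank_answer` picks an answer in two stages: it first keeps the `k` candidates whose first token is most probable, then rescores only that shortlist by full-sequence log-likelihood and returns the best one. The sketch below shows that control flow for a single question in plain Python. The two score lists are toy values standing in for the decoder's outputs, and `rank_answers` is an illustrative name rather than a function from this repository.

```python
def rank_answers(first_token_prob, sequence_log_prob, k):
    """Two-stage ranking for one question.

    first_token_prob[i]:  probability of candidate i's first token (toy values).
    sequence_log_prob[i]: log-likelihood of the whole candidate i (toy values).
    Returns the index of the selected candidate.
    """
    # Stage 1: shortlist the k candidates with the most probable first token.
    order = sorted(range(len(first_token_prob)),
                   key=lambda i: first_token_prob[i], reverse=True)
    shortlist = order[:k]
    # Stage 2: within the shortlist, take the highest full-sequence score.
    return max(shortlist, key=lambda i: sequence_log_prob[i])


first = [0.05, 0.40, 0.30, 0.25]
full = [-0.1, -3.0, -1.2, -2.0]
print(rank_answers(first, full, k=2))  # candidate 0 is pruned in stage 1
print(rank_answers(first, full, k=4))  # with no pruning, candidate 0 wins
```

The shortlist is an approximation: a candidate with a strong overall score but an unlikely first token is never rescored, which is the trade-off `k_test` controls.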


def blip_vqa(pretrained='',**kwargs):
    model = BLIP_VQA(**kwargs)
    if pretrained:
        model,msg = load_checkpoint(model,pretrained)
#         assert(len(msg.missing_keys)==0)
    return model


def tile(x, dim, n_tile):
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(x, dim, order_index.to(x.device))
\ No newline at end of file
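The `tile` helper repeats a tensor `n_tile` times along `dim` and then reorders the result so that the copies of each original entry sit next to each other. A minimal sketch of that index arithmetic on Python lists follows; `tile_order` and `tile_rows` are illustrative names, and a list stands in for dimension 0 of a tensor.

```python
def tile_order(init_dim, n_tile):
    """Index order built by tile(): for each entry i, its position in every copy."""
    return [init_dim * j + i for i in range(init_dim) for j in range(n_tile)]


def tile_rows(rows, n_tile):
    repeated = rows * n_tile  # list analogue of x.repeat(...) along dim 0
    return [repeated[idx] for idx in tile_order(len(rows), n_tile)]


print(tile_order(2, 3))
print(tile_rows(['q0', 'q1'], 3))
```

The result is each row repeated consecutively, which is the same ordering that `repeat_interleave(n_tile, dim=dim)` produces and that the `generate` branch of `forward` already relies on.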
"""
Download the weights in ./checkpoints beforehand for fast inference
wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth
wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth
wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth
"""
from pathlib import Path
from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import cog
from models.blip import blip_decoder
from models.blip_vqa import blip_vqa
from models.blip_itm import blip_itm


class Predictor(cog.Predictor):
    def setup(self):
        self.device = "cuda:0"

        self.models = {
            'image_captioning': blip_decoder(pretrained='checkpoints/model*_base_caption.pth',
                                             image_size=384, vit='base'),
            'visual_question_answering': blip_vqa(pretrained='checkpoints/model*_vqa.pth',
                                                  image_size=480, vit='base'),
            'image_text_matching': blip_itm(pretrained='checkpoints/model_base_retrieval_coco.pth',
                                            image_size=384, vit='base')
        }

    @cog.input(
        "image",
        type=Path,
        help="input image",
    )
    @cog.input(
        "task",
        type=str,
        default='image_captioning',
        options=['image_captioning', 'visual_question_answering', 'image_text_matching'],
        help="Choose a task.",
    )
    @cog.input(
        "question",
        type=str,
        default=None,
        help="Type question for the input image for visual question answering task.",
    )
    @cog.input(
        "caption",
        type=str,
        default=None,
        help="Type caption for the input image for image text matching task.",
    )
    def predict(self, image, task, question, caption):
        if task == 'visual_question_answering':
            assert question is not None, 'Please type a question for visual question answering task.'
        if task == 'image_text_matching':
            assert caption is not None, 'Please type a caption for image text matching task.'

        # the VQA model expects 480x480 input; the other two tasks use 384x384
        im = load_image(image, image_size=480 if task == 'visual_question_answering' else 384, device=self.device)
        model = self.models[task]
        model.eval()
        model = model.to(self.device)

        if task == 'image_captioning':
            with torch.no_grad():
                caption = model.generate(im, sample=False, num_beams=3, max_length=20, min_length=5)
                return 'Caption: ' + caption[0]

        if task == 'visual_question_answering':
            with torch.no_grad():
                answer = model(im, question, train=False, inference='generate')
                return 'Answer: ' + answer[0]

        # image_text_matching
        itm_output = model(im, caption, match_head='itm')
        itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]
        itc_score = model(im, caption, match_head='itc')
        return f'The image and text are matched with a probability of {itm_score.item():.4f}.\n' \
               f'The image feature and text feature have a cosine similarity of {itc_score.item():.4f}.'


def load_image(image, image_size, device):
    raw_image = Image.open(str(image)).convert('RGB')

    w, h = raw_image.size

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])
    image = transform(raw_image).unsqueeze(0).to(device)
    return image
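
After resizing, `load_image` applies `ToTensor` (8-bit values scaled to the range 0 to 1) and then `Normalize` (per-channel subtraction of the mean and division by the standard deviation). The sketch below works through that arithmetic for a single RGB pixel in plain Python, using the constants from `load_image`; `normalize_pixel` is an illustrative helper, and the bicubic resize step is left out.

```python
MEAN = (0.48145466, 0.4578275, 0.40821073)
STD = (0.26862954, 0.26130258, 0.27577711)


def normalize_pixel(rgb):
    """Map one 8-bit RGB pixel to the per-channel values the model receives."""
    scaled = [c / 255.0 for c in rgb]  # ToTensor: [0, 255] -> [0.0, 1.0]
    return [(c - m) / s for c, m, s in zip(scaled, MEAN, STD)]  # Normalize


black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
print([round(v, 4) for v in black])
print([round(v, 4) for v in white])
```

Black pixels map to negative values and white pixels to positive ones, so the inputs are roughly centred around zero for each channel.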
timm==0.4.12
transformers==4.15.0
fairscale==0.4.4
pycocoevalcap
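
Three of the four requirements above are pinned to exact versions and `pycocoevalcap` is left unpinned. A small sketch, assuming only the `name==version` form used in this file, shows how the pins could be read into a dict before comparing them against an environment; `parse_requirements` is a hypothetical helper and does not handle other specifiers such as `>=` or extras.

```python
def parse_requirements(text):
    """Parse 'name==version' lines; unpinned names map to None."""
    pins = {}
    for line in text.splitlines():
        line = line.split('#', 1)[0].strip()  # drop comments and whitespace
        if not line:
            continue
        name, sep, version = line.partition('==')
        pins[name.strip()] = version.strip() if sep else None
    return pins


REQUIREMENTS = """\
timm==0.4.12
transformers==4.15.0
fairscale==0.4.4
pycocoevalcap
"""

pins = parse_requirements(REQUIREMENTS)
print(pins)
```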
.vscode
# ignored files
version.py
# ignored files with suffix
*.html
# *.png
# *.jpeg
# *.jpg
*.pt
*.gif
*.pth
*.dat
*.zip
# template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# project
results/
dlib/
*_old*
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
modulated_deform_conv)
__all__ = [
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv'
]
from .fused_act import FusedLeakyReLU, fused_leaky_relu
__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']