条件付きC RNN GANで音楽を生成
前置き
昨年のGANの動向を見ると,半教師あり学習・条件付きでの生成がトレンドだったように思います. そこでC-RNN-GANのモデルを元にを条件付きで音楽を生成したいと思います.
条件付きGAN
条件付きのGANについてはこちらの論文を参照します.
モデル図は下記の通りのようになっています.
生成器には,ランダムな分布zとラベルyを元に生成しています.
一方識別器は,生成されてもの/学習データxとラベルyを入力しxであるかの真偽を出力しています.
最近ではこれとCycleGANなどを組み合わせたStarGANなどがあります.
C-RNN-GAN
C-RNN-GANは以下のようなモデル図になっています.
基本的に生成器も識別器もRNNになった感じです.
ただ識別器側はBiRNNになっています.
また損失関数も通常のGANと変わりはありません.
Conditional C-RNN-GAN
モデル形状はC RNN GANで同じで,両方のモデルにラベルを加えて入力しています. 損失はWGANを使用しました.
生成結果はこちらにおいておきます.
Conditional_C_RNN_GAN/generated_mid at wgan · TrsNium/Conditional_C_RNN_GAN · GitHub
聞いてみればわかるのですが,mode collapseがおきています.
これを回避するためにWGANなどを入れて見たのですが上手く学習をすることができなかったようです.
コードは以下に置いておきます.
データセット
freemidiからデータを集めました.
一応カテゴライズされているので,そのカテゴリーに沿ってラベルを作成しデータセットを作成しました.
スクレイピングのコードは下記にあります.
実行するときは,chrome driverが必要なのでご用意してください.
from bs4 import BeautifulSoup from urllib.request import urlopen from urllib.request import urlretrieve import re import os from selenium import webdriver from selenium.webdriver.common.keys import Keys import time content_url = "https://freemidi.org/" html_doc = urlopen(content_url+"genre").read() sp = BeautifulSoup(html_doc) genres = sp.find_all("div", {"class":"genre-big-ones"}) genres_href = [[tag["href"] for tag in genre.find_all("a")] for genre in genres] ''' [['genre-rock', 'genre-pop', 'genre-hip-hop-rap', 'genre-rnb-soul'], ['genre-classical', 'genre-country', 'genre-jazz', 'genre-blues'], ['genre-dance-eletric', 'genre-folk', 'genre-punk', 'genre-newage'], ['genre-reggae-ska', 'genre-metal', 'genre-disco', 'genre-bluegrass']] ''' genre_artist = {} #お好きなジャンルをお選びください selected_genre = [] for hrefs in genres_href: for href in hrefs: if not href in selected_genre: continue content = {} url = content_url + href html_doc = BeautifulSoup(urlopen(url).read(), "lxml") artist_hrefs = html_doc.find_all("div", {"class": "genre-link-text"}) print(href) for artist_href in artist_hrefs: artist_href_ = artist_href.a.get("href") a_html_doc = BeautifulSoup(urlopen(content_url+artist_href_).read()) content[artist_href.a.string] = {re.sub("\r\n\s{2,}", "", artist_href.a.string):a.get("href") for a in a_html_doc.find_all("a", {"itemprop":"url"})[1:]} genre_artist[href] = content cwd = os.getcwd() for key, item in genre_artist.items(): if not os.path.exists(key): os.mkdir(key) chromeOptions = webdriver.ChromeOptions() prefs = {"download.default_directory" : cwd+"/"+key+"/"} chromeOptions.add_experimental_option("prefs",prefs) for artist_n, song_dict in item.items(): for song_n, song_href in song_dict.items(): try: browser = webdriver.Chrome(executable_path="chromedriver", chrome_options=chromeOptions) browser.get(content_url+song_href) browser.find_element_by_link_text('Download MIDI').click() time.sleep(5) browser.quit() except: continue
おわり
想定していた結果が得れませんでしたが,考え直すとデータ数が少なかったのかもしれません. データ数を増やしたり損失関数を工夫してもう一度挑戦して行きたいと思います.